diff --git a/.gitignore b/.gitignore index 9ae0d9c96f188bc6357832f22b4125694302b104..be75938ec401b1d72fa54773c85191aaac7d7f35 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ node_modules /bazel-* /bazel_pip /tools/python_bin_path.sh -/tools/git/gen +/tensorflow/tools/git/gen /pip_test /_python_build *.pyc @@ -22,3 +22,15 @@ Pods Podfile.lock *.pbxproj *.xcworkspacedata +/tensorflow/contrib/lite/downloads/** +/tensorflow/contrib/lite/gen/** +/tensorflow/contrib/lite/examples/ios/simple/data/*.txt +/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite +xcuserdata/** + +# Android +.gradle +.idea +*.iml +local.properties +gradleBuild diff --git a/AUTHORS b/AUTHORS index a46ae7e616ab3a420d9fb2691ee8d8650032a39f..aa4be5169dcc68c579863e8ba6307cd00e9f9a68 100644 --- a/AUTHORS +++ b/AUTHORS @@ -7,4 +7,4 @@ # The email address is not required for organizations. Google Inc. -Yuan Tang terrytangyuan@gmail.com +Yuan Tang diff --git a/CODEOWNERS b/CODEOWNERS index 6e4b4f5f3f751ca9ab39a5772458349b00f06d57..57a4df40e651f45dc03493af631d73332e46c182 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -11,6 +11,7 @@ # NEED OWNER: tensorflow/contrib/avro/* #tensorflow/contrib/batching/* @alextp @chrisolston #tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon +#tensorflow/contrib/boosted_trees/* @sshrdp @yk5 @nataliaponomareva #tensorflow/contrib/cmake/* @mrry @benoitsteiner #tensorflow/contrib/copy_graph/* @tucker @poxvoculi #tensorflow/contrib/crf/* @kentonl diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 10fd595fec7f240c3fdc871e1f32cc83f2ffd46d..ff11d131409b65880f16b80f9fe38dc39ac0e5fa 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -42,7 +42,7 @@ The Code of Conduct also applies within project spaces and in public spaces when Conflicts in an open source project can take many forms, from someone having a bad day and using harsh and hurtful language in the issue queue, to more serious instances such as sexist/racist statements or threats of violence, and everything in between. -If the behaviour is threatening or harassing, or for other reasons requires immediate escalation, please see below. +If the behavior is threatening or harassing, or for other reasons requires immediate escalation, please see below. However, for the vast majority of issues, we aim to empower individuals to first resolve conflicts themselves, asking for help when needed, and only after that fails to escalate further. This approach gives people more control over the outcome of their dispute. @@ -55,14 +55,14 @@ If you are experiencing or witnessing conflict, we ask you to use the following ## Reporting Violations -Violations of the Code of Conduct can be reported to TensorFlow’s Project Steward at conduct@tensorflow.org. The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. +Violations of the Code of Conduct can be reported to TensorFlow’s Project Stewards, Edd Wilder-James (ewj@google.com) and Sarah Novotny (sarahnovotny@google.com). The Project Steward will determine whether the Code of Conduct was violated, and will issue an appropriate sanction, possibly including a written warning or expulsion from the project, project sponsored spaces, or project forums. We ask that you make a good-faith effort to resolve your conflict via the conflict resolution policy before submitting a report. Violations of the Code of Conduct can occur in any setting, even those unrelated to the project. We will only consider complaints about conduct that has occurred within one year of the report. ## Enforcement -If the Project Steward receives a report alleging a violation of the Code of Conduct, the Project Steward will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Steward will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Steward may issue sanctions without notice. +If the Project Stewards receive a report alleging a violation of the Code of Conduct, the Project Stewards will notify the accused of the report, and provide them an opportunity to discuss the report before a sanction is issued. The Project Stewards will do their utmost to keep the reporter anonymous. If the act is ongoing (such as someone engaging in harassment), or involves a threat to anyone's safety (e.g. threats of violence), the Project Stewards may issue sanctions without notice. ## Attribution diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 43abdaafbf45379430920cd027b26299cd62553b..1b537ca73cc94e992e7537fe69c8d0cc8fd13102 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -114,6 +114,7 @@ pylint --rcfile=/tmp/pylintrc myfile.py * [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) * [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) * [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) +* [Google Objective-C Style Guide](http://google.github.io/styleguide/objcguide.html) #### Running sanity check diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 2bf2c754cf64ec3bac22a22fbafcebbd4dc54bf4..1a401997c649518766acb2ebb0dea1c128bd0ba4 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -19,6 +19,7 @@ If you open a GitHub issue, here is our policy: - **TensorFlow version (use command below)**: - **Python version**: - **Bazel version (if compiling from source)**: +- **GCC/Compiler version (if compiling from source)**: - **CUDA/cuDNN version**: - **GPU model and memory**: - **Exact command to reproduce**: diff --git a/README.md b/README.md index 24bbb6cec10e16c7b6ae37b7cf8b6f90ebe5e5dd..aff3427bddb307aea6d6c2466eac14c9edffcc32 100644 --- a/README.md +++ b/README.md @@ -73,11 +73,11 @@ $ python ## For more information -* [TensorFlow website](https://www.tensorflow.org) +* [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) -* [TensorFlow course at Stanford](https://web.stanford.edu/class/cs20si) +* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. diff --git a/RELEASE.md b/RELEASE.md index d8db1f72004b5d944e3035a0f33dfc34a674b7ee..e04bd3fc505d51ade9e9fa12c822cb695e90b4f3 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -494,7 +494,7 @@ answered questions, and were part of inspiring discussions. This release contains contributions from many people at Google, as well as: A. Besir Kurtulmus, Adal Chiriliuc, @akash, Alec-Desouza, Alex Rothberg, Alex -Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton +Sergeev, Alexander Heinecke, Allen Guo, Andreas Madsen, Ankesh Anand, Anton Loss, @Aravind, @Arie, Ashutosh Das, AuréLien Geron, Bairen Yi, @bakunyo, Ben Visser, Brady Zhou, Calpa Liu, Changming Sun, Chih Cheng Liang, Christopher Berner, Clark Zinzow, @Conchylicultor, Dan Ellis, Dan J, Dan Jarvis, Daniel diff --git a/configure.py b/configure.py index bc7859fee4d2aca9bd7ca24e85ad820c49e01e4a..7a9d315eb0ededf273d1cee3d06cb9b53864a834 100644 --- a/configure.py +++ b/configure.py @@ -25,15 +25,19 @@ import re import subprocess import sys +# pylint: disable=g-import-not-at-top try: from shutil import which except ImportError: from distutils.spawn import find_executable as which +# pylint: enable=g-import-not-at-top _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.tf_configure.bazelrc') -_DEFAULT_CUDA_VERSION = '8.0' -_DEFAULT_CUDNN_VERSION = '6' +_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'WORKSPACE') +_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' @@ -41,6 +45,14 @@ _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' +_DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] + +_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 + + +class UserInputError(Exception): + pass def is_windows(): @@ -155,7 +167,7 @@ def get_python_path(environ_cp, python_bin_path): try: library_paths = run_shell( [python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))']).split("\n") + 'import site; print("\\n".join(site.getsitepackages()))']).split('\n') except subprocess.CalledProcessError: library_paths = [run_shell( [python_bin_path, '-c', @@ -226,17 +238,9 @@ def setup_python(environ_cp): # Set-up env variables used by python_configure.bzl write_action_env_to_bazelrc('PYTHON_BIN_PATH', python_bin_path) write_action_env_to_bazelrc('PYTHON_LIB_PATH', python_lib_path) - write_to_bazelrc('build --define PYTHON_BIN_PATH="%s"' % python_bin_path) - write_to_bazelrc('build --define PYTHON_LIB_PATH="%s"' % python_lib_path) write_to_bazelrc('build --force_python=py%s' % python_major_version) write_to_bazelrc('build --host_force_python=py%s' % python_major_version) write_to_bazelrc('build --python_path=\"%s"' % python_bin_path) - write_to_bazelrc('test --force_python=py%s' % python_major_version) - write_to_bazelrc('test --host_force_python=py%s' % python_major_version) - write_to_bazelrc('test --define PYTHON_BIN_PATH="%s"' % python_bin_path) - write_to_bazelrc('test --define PYTHON_LIB_PATH="%s"' % python_lib_path) - write_to_bazelrc('run --define PYTHON_BIN_PATH="%s"' % python_bin_path) - write_to_bazelrc('run --define PYTHON_LIB_PATH="%s"' % python_lib_path) environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh @@ -485,7 +489,14 @@ def set_cc_opt_flags(environ_cp): cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', question, default_cc_opt_flags) for opt in cc_opt_flags.split(): - write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt)) + write_to_bazelrc('build:opt --copt=%s' % opt) + # It should be safe on the same build host. + write_to_bazelrc('build:opt --host_copt=-march=native') + write_to_bazelrc('build:opt --define with_default_optimizations=true') + # TODO(mikecase): Remove these default defines once we are able to get + # TF Lite targets building without them. + write_to_bazelrc('build --copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') + write_to_bazelrc('build --host_copt=-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK') def set_tf_cuda_clang(environ_cp): @@ -555,6 +566,218 @@ def set_clang_cuda_compiler_path(environ_cp): clang_cuda_compiler_path) +def prompt_loop_or_load_from_env( + environ_cp, + var_name, + var_default, + ask_for_var, + check_success, + error_msg, + suppress_default_error=False, + n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS +): + """Loop over user prompts for an ENV param until receiving a valid response. + + For the env param var_name, read from the environment or verify user input + until receiving valid input. When done, set var_name in the environ_cp to its + new value. + + Args: + environ_cp: (Dict) copy of the os.environ. + var_name: (String) string for name of environment variable, e.g. "TF_MYVAR". + var_default: (String) default value string. + ask_for_var: (String) string for how to ask for user input. + check_success: (Function) function that takes one argument and returns a + boolean. Should return True if the value provided is considered valid. May + contain a complex error message if error_msg does not provide enough + information. In that case, set suppress_default_error to True. + error_msg: (String) String with one and only one '%s'. Formatted with each + invalid response upon check_success(input) failure. + suppress_default_error: (Bool) Suppress the above error message in favor of + one from the check_success function. + n_ask_attempts: (Integer) Number of times to query for valid input before + raising an error and quitting. + + Returns: + [String] The value of var_name after querying for input. + + Raises: + UserInputError: if a query has been attempted n_ask_attempts times without + success, assume that the user has made a scripting error, and will continue + to provide invalid input. Raise the error to avoid infinitely looping. + """ + default = environ_cp.get(var_name) or var_default + full_query = '%s [Default is %s]: ' % ( + ask_for_var, + default, + ) + + for _ in range(n_ask_attempts): + val = get_from_env_or_user_or_default(environ_cp, + var_name, + full_query, + default) + if check_success(val): + break + if not suppress_default_error: + print(error_msg % val) + environ_cp[var_name] = '' + else: + raise UserInputError('Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' % + (var_name, n_ask_attempts)) + + environ_cp[var_name] = val + return val + + +def create_android_ndk_rule(environ_cp): + """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % + environ_cp['APPDATA']) + elif is_macos(): + default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + + def valid_ndk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'source.properties'))) + + android_ndk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_NDK_HOME', + var_default=default_ndk_path, + ask_for_var='Please specify the home path of the Android NDK to use.', + check_success=valid_ndk_path, + error_msg=('The path %s or its child file "source.properties" ' + 'does not exist.') + ) + + write_android_ndk_workspace_rule(android_ndk_home_path) + + +def create_android_sdk_rule(environ_cp): + """Set Android variables and write Android SDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA']) + elif is_macos(): + default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME'] + + def valid_sdk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'platforms')) and + os.path.exists(os.path.join(path, 'build-tools'))) + + android_sdk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_SDK_HOME', + var_default=default_sdk_path, + ask_for_var='Please specify the home path of the Android SDK to use.', + check_success=valid_sdk_path, + error_msg=('Either %s does not exist, or it does not contain the ' + 'subdirectories "platforms" and "build-tools".')) + + platforms = os.path.join(android_sdk_home_path, 'platforms') + api_levels = sorted(os.listdir(platforms)) + api_levels = [x.replace('android-', '') for x in api_levels] + + def valid_api_level(api_level): + return os.path.exists(os.path.join(android_sdk_home_path, + 'platforms', + 'android-' + api_level)) + + android_api_level = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_API_LEVEL', + var_default=api_levels[-1], + ask_for_var=('Please specify the Android SDK API level to use. ' + '[Available levels: %s]') % api_levels, + check_success=valid_api_level, + error_msg='Android-%s is not present in the SDK path.') + + build_tools = os.path.join(android_sdk_home_path, 'build-tools') + versions = sorted(os.listdir(build_tools)) + + def valid_build_tools(version): + return os.path.exists(os.path.join(android_sdk_home_path, + 'build-tools', + version)) + + android_build_tools_version = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_BUILD_TOOLS_VERSION', + var_default=versions[-1], + ask_for_var=('Please specify an Android build tools version to use. ' + '[Available versions: %s]') % versions, + check_success=valid_build_tools, + error_msg=('The selected SDK does not have build-tools version %s ' + 'available.')) + + write_android_sdk_workspace_rule(android_sdk_home_path, + android_build_tools_version, + android_api_level) + + +def write_android_sdk_workspace_rule(android_sdk_home_path, + android_build_tools_version, + android_api_level): + print('Writing android_sdk_workspace rule.\n') + with open(_TF_WORKSPACE, 'a') as f: + f.write(""" +android_sdk_repository( + name="androidsdk", + api_level=%s, + path="%s", + build_tools_version="%s")\n +""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) + + +def write_android_ndk_workspace_rule(android_ndk_home_path): + print('Writing android_ndk_workspace rule.') + ndk_api_level = check_ndk_level(android_ndk_home_path) + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + with open(_TF_WORKSPACE, 'a') as f: + f.write(""" +android_ndk_repository( + name="androidndk", + path="%s", + api_level=%s)\n +""" % (android_ndk_home_path, ndk_api_level)) + + +def check_ndk_level(android_ndk_home_path): + """Check the revision number of an Android NDK path.""" + properties_path = '%s/source.properties' % android_ndk_home_path + if is_windows() or is_cygwin(): + properties_path = cygpath(properties_path) + with open(properties_path, 'r') as f: + filedata = f.read() + + revision = re.search(r'Pkg.Revision = (\d+)', filedata) + if revision: + return revision.group(1) + return None + + +def workspace_has_any_android_rule(): + """Check the WORKSPACE for existing android_*_repository rules.""" + with open(_TF_WORKSPACE, 'r') as f: + workspace = f.read() + has_any_rule = re.search(r'^android_[ns]dk_repository', + workspace, + re.MULTILINE) + return has_any_rule + + def set_gcc_host_compiler_path(environ_cp): """Set GCC_HOST_COMPILER_PATH.""" default_gcc_host_compiler_path = which('gcc') or '' @@ -564,23 +787,16 @@ def set_gcc_host_compiler_path(environ_cp): # os.readlink is only available in linux default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) - ask_gcc_path = ( - 'Please specify which gcc should be used by nvcc as the ' - 'host compiler. [Default is %s]: ') % default_gcc_host_compiler_path - while True: - gcc_host_compiler_path = get_from_env_or_user_or_default( - environ_cp, 'GCC_HOST_COMPILER_PATH', ask_gcc_path, - default_gcc_host_compiler_path) - - if os.path.exists(gcc_host_compiler_path): - break - - # Reset and retry - print('Invalid gcc path. %s cannot be found' % gcc_host_compiler_path) - environ_cp['GCC_HOST_COMPILER_PATH'] = '' + gcc_host_compiler_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='GCC_HOST_COMPILER_PATH', + var_default=default_gcc_host_compiler_path, + ask_for_var= + 'Please specify which gcc should be used by nvcc as the host compiler.', + check_success=os.path.exists, + error_msg='Invalid gcc path. %s cannot be found.', + ) - # Set GCC_HOST_COMPILER_PATH - environ_cp['GCC_HOST_COMPILER_PATH'] = gcc_host_compiler_path write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) @@ -635,7 +851,7 @@ def set_tf_cuda_version(environ_cp): write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version) -def set_tf_cunn_version(environ_cp): +def set_tf_cudnn_version(environ_cp): """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION.""" ask_cudnn_version = ( 'Please specify the cuDNN version you want to use. ' @@ -808,102 +1024,153 @@ def set_other_cuda_vars(environ_cp): def set_host_cxx_compiler(environ_cp): """Set HOST_CXX_COMPILER.""" default_cxx_host_compiler = which('g++') or '' - ask_cxx_host_compiler = ( - 'Please specify which C++ compiler should be used as' - ' the host C++ compiler. [Default is %s]: ') % default_cxx_host_compiler - while True: - host_cxx_compiler = get_from_env_or_user_or_default( - environ_cp, 'HOST_CXX_COMPILER', ask_cxx_host_compiler, - default_cxx_host_compiler) - if os.path.exists(host_cxx_compiler): - break - - # Reset and retry - print('Invalid C++ compiler path. %s cannot be found' % host_cxx_compiler) - environ_cp['HOST_CXX_COMPILER'] = '' + host_cxx_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_CXX_COMPILER', + var_default=default_cxx_host_compiler, + ask_for_var=('Please specify which C++ compiler should be used as the ' + 'host C++ compiler.'), + check_success=os.path.exists, + error_msg='Invalid C++ compiler path. %s cannot be found.', + ) - # Set HOST_CXX_COMPILER - environ_cp['HOST_CXX_COMPILER'] = host_cxx_compiler write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler) def set_host_c_compiler(environ_cp): """Set HOST_C_COMPILER.""" default_c_host_compiler = which('gcc') or '' - ask_c_host_compiler = ( - 'Please specify which C compiler should be used as the' - ' host C compiler. [Default is %s]: ') % default_c_host_compiler - while True: - host_c_compiler = get_from_env_or_user_or_default( - environ_cp, 'HOST_C_COMPILER', ask_c_host_compiler, - default_c_host_compiler) - if os.path.exists(host_c_compiler): - break - - # Reset and retry - print('Invalid C compiler path. %s cannot be found' % host_c_compiler) - environ_cp['HOST_C_COMPILER'] = '' + host_c_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_C_COMPILER', + var_default=default_c_host_compiler, + ask_for_var=('Please specify which C compiler should be used as the host' + 'C compiler.'), + check_success=os.path.exists, + error_msg='Invalid C compiler path. %s cannot be found.', + ) - # Set HOST_C_COMPILER - environ_cp['HOST_C_COMPILER'] = host_c_compiler write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) def set_computecpp_toolkit_path(environ_cp): """Set COMPUTECPP_TOOLKIT_PATH.""" - ask_computecpp_toolkit_path = ('Please specify the location where ComputeCpp ' - 'for SYCL %s is installed. [Default is %s]: ' - ) % (_TF_OPENCL_VERSION, - _DEFAULT_COMPUTECPP_TOOLKIT_PATH) - while True: - computecpp_toolkit_path = get_from_env_or_user_or_default( - environ_cp, 'COMPUTECPP_TOOLKIT_PATH', ask_computecpp_toolkit_path, - _DEFAULT_COMPUTECPP_TOOLKIT_PATH) + def toolkit_exists(toolkit_path): + """Check if a computecpp toolkit path is valid.""" if is_linux(): sycl_rt_lib_path = 'lib/libComputeCpp.so' else: sycl_rt_lib_path = '' - sycl_rt_lib_path_full = os.path.join(computecpp_toolkit_path, + sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) - if os.path.exists(sycl_rt_lib_path_full): - break + exists = os.path.exists(sycl_rt_lib_path_full) + if not exists: + print('Invalid SYCL %s library path. %s cannot be found' % + (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) + return exists - print('Invalid SYCL %s library path. %s cannot be found' % - (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) - environ_cp['COMPUTECPP_TOOLKIT_PATH'] = '' + computecpp_toolkit_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='COMPUTECPP_TOOLKIT_PATH', + var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH, + ask_for_var=( + 'Please specify the location where ComputeCpp for SYCL %s is ' + 'installed.' % _TF_OPENCL_VERSION), + check_success=toolkit_exists, + error_msg='Invalid SYCL compiler path. %s cannot be found.', + suppress_default_error=True) - # Set COMPUTECPP_TOOLKIT_PATH - environ_cp['COMPUTECPP_TOOLKIT_PATH'] = computecpp_toolkit_path write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH', computecpp_toolkit_path) +def set_trisycl_include_dir(environ_cp): + """Set TRISYCL_INCLUDE_DIR""" + ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' + 'include directory. (Use --config=sycl_trisycl ' + 'when building with Bazel) ' + '[Default is %s]: ' + ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) + while True: + trisycl_include_dir = get_from_env_or_user_or_default( + environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, + _DEFAULT_TRISYCL_INCLUDE_DIR) + if os.path.exists(trisycl_include_dir): + break + + print('Invalid triSYCL include directory, %s cannot be found' + % (trisycl_include_dir)) + + # Set TRISYCL_INCLUDE_DIR + environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir + write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', + trisycl_include_dir) + +def set_trisycl_include_dir(environ_cp): + """Set TRISYCL_INCLUDE_DIR.""" + ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' + 'include directory. (Use --config=sycl_trisycl ' + 'when building with Bazel) ' + '[Default is %s]: ') % _DEFAULT_TRISYCL_INCLUDE_DIR + while True: + trisycl_include_dir = get_from_env_or_user_or_default( + environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, + _DEFAULT_TRISYCL_INCLUDE_DIR) + if os.path.exists(trisycl_include_dir): + break + + print('Invalid triSYCL include directory, %s cannot be found' + % (trisycl_include_dir)) + + # Set TRISYCL_INCLUDE_DIR + environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir + write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', + trisycl_include_dir) + + +def set_trisycl_include_dir(environ_cp): + """Set TRISYCL_INCLUDE_DIR.""" + + trisycl_include_dir = prompt_loop_or_load_from_env( + environ_cp, + var_name='TRISYCL_INCLUDE_DIR', + var_default=_DEFAULT_TRISYCL_INCLUDE_DIR, + ask_for_var=('Please specify the location of the triSYCL include ' + 'directory. (Use --config=sycl_trisycl when building with ' + 'Bazel)'), + check_success=os.path.exists, + error_msg='Invalid trySYCL include directory. %s cannot be found.', + suppress_default_error=True) + + write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) + def set_mpi_home(environ_cp): """Set MPI_HOME.""" + default_mpi_home = which('mpirun') or which('mpiexec') or '' default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) - ask_mpi_home = ('Please specify the MPI toolkit folder. [Default is %s]: ' - ) % default_mpi_home - while True: - mpi_home = get_from_env_or_user_or_default(environ_cp, 'MPI_HOME', - ask_mpi_home, default_mpi_home) - - if os.path.exists(os.path.join(mpi_home, 'include')) and os.path.exists( - os.path.join(mpi_home, 'lib')): - break - - print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % - (os.path.join(mpi_home, 'include'), - os.path.exists(os.path.join(mpi_home, 'lib')))) - environ_cp['MPI_HOME'] = '' + def valid_mpi_path(mpi_home): + exists = (os.path.exists(os.path.join(mpi_home, 'include')) and + os.path.exists(os.path.join(mpi_home, 'lib'))) + if not exists: + print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % + (os.path.join(mpi_home, 'include'), + os.path.exists(os.path.join(mpi_home, 'lib')))) + return exists - # Set MPI_HOME - environ_cp['MPI_HOME'] = str(mpi_home) + _ = prompt_loop_or_load_from_env( + environ_cp, + var_name='MPI_HOME', + var_default=default_mpi_home, + ask_for_var='Please specify the MPI toolkit folder.', + check_success=valid_mpi_path, + error_msg='', + suppress_default_error=True) def set_other_mpi_vars(environ_cp): @@ -941,13 +1208,12 @@ def set_other_mpi_vars(environ_cp): def set_mkl(): write_to_bazelrc('build:mkl --define using_mkl=true') write_to_bazelrc('build:mkl -c opt') - write_to_bazelrc('build:mkl --copt="-DEIGEN_USE_VML"') print( 'Add "--config=mkl" to your bazel command to build with MKL ' 'support.\nPlease note that MKL on MacOS or windows is still not ' 'supported.\nIf you would like to use a local MKL instead of ' 'downloading, please set the environment variable \"TF_MKL_ROOT\" every ' - 'time before build.') + 'time before build.\n') def set_monolithic(): @@ -976,6 +1242,19 @@ def create_android_bazelrc_configs(): write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a') +def set_grpc_build_flags(): + write_to_bazelrc('build --define grpc_no_ares=true') + +def set_windows_build_flags(): + if is_windows(): + # The non-monolithic build is not supported yet + write_to_bazelrc('build --config monolithic') + # Suppress warning messages + write_to_bazelrc('build --copt=-w --host_copt=-w') + # Output more verbose information when something goes wrong + write_to_bazelrc('build --verbose_failures') + + def main(): # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. @@ -993,8 +1272,9 @@ def main(): environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' + environ_cp['TF_NEED_OPENCL_SYCL'] = '0' + environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' - environ_cp['TF_NEED_S3'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' if is_macos(): @@ -1015,17 +1295,21 @@ def main(): set_build_var(environ_cp, 'TF_NEED_VERBS', 'VERBS', 'with_verbs_support', False, 'verbs') - set_action_env_var(environ_cp, 'TF_NEED_OPENCL', 'OpenCL', False) - if environ_cp.get('TF_NEED_OPENCL') == '1': + set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False) + if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1': set_host_cxx_compiler(environ_cp) set_host_c_compiler(environ_cp) - set_computecpp_toolkit_path(environ_cp) + set_action_env_var(environ_cp, 'TF_NEED_COMPUTECPP', 'ComputeCPP', True) + if environ_cp.get('TF_NEED_COMPUTECPP') == '1': + set_computecpp_toolkit_path(environ_cp) + else: + set_trisycl_include_dir(environ_cp) set_action_env_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False) if (environ_cp.get('TF_NEED_CUDA') == '1' and 'TF_CUDA_CONFIG_REPO' not in environ_cp): set_tf_cuda_version(environ_cp) - set_tf_cunn_version(environ_cp) + set_tf_cudnn_version(environ_cp) set_tf_cuda_compute_capabilities(environ_cp) set_tf_cuda_clang(environ_cp) @@ -1044,10 +1328,29 @@ def main(): set_mpi_home(environ_cp) set_other_mpi_vars(environ_cp) + set_grpc_build_flags() set_cc_opt_flags(environ_cp) set_mkl() set_monolithic() + set_windows_build_flags() create_android_bazelrc_configs() + if workspace_has_any_android_rule(): + print('The WORKSPACE file has at least one of ["android_sdk_repository", ' + '"android_ndk_repository"] already set. Will not ask to help ' + 'configure the WORKSPACE. Please delete the existing rules to ' + 'activate the helper.\n') + else: + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) + + if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 3f23203aefd3d42c12c6a40f3711bcdedd22fd23..0054ce4b39e3a054f318b9766be41ec7efe79c0f 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -54,6 +54,15 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "raspberry_pi_armeabi", + values = { + "crosstool_top": "@local_config_arm_compiler//:toolchain", + "cpu": "armeabi", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "android_arm", values = { @@ -110,7 +119,7 @@ config_setting( config_setting( name = "no_tensorflow_py_deps", - values = {"define": "no_tensorflow_py_deps=true"}, + define_values = {"no_tensorflow_py_deps": "true"}, visibility = ["//visibility:public"], ) @@ -166,55 +175,122 @@ config_setting( # TODO(jhseu): Enable on other platforms other than Linux. config_setting( name = "with_jemalloc_linux_x86_64", - values = { - "cpu": "k8", - "define": "with_jemalloc=true", - }, + define_values = {"with_jemalloc": "true"}, + values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) config_setting( name = "with_jemalloc_linux_ppc64le", - values = { - "cpu": "ppc", - "define": "with_jemalloc=true", - }, + define_values = {"with_jemalloc": "true"}, + values = {"cpu": "ppc"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_default_optimizations", + define_values = {"with_default_optimizations": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_gcp_support", - values = {"define": "with_gcp_support=true"}, + define_values = {"with_gcp_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_hdfs_support", - values = {"define": "with_hdfs_support=true"}, + define_values = {"with_hdfs_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_s3_support", - values = {"define": "with_s3_support=true"}, + define_values = {"with_s3_support": "true"}, + visibility = ["//visibility:public"], +) + +# Crosses between platforms and file system libraries not supported on those +# platforms due to limitations in nested select() statements. +config_setting( + name = "with_gcp_support_windows_override", + define_values = {"with_gcp_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_windows_override", + define_values = {"with_hdfs_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_windows_override", + define_values = {"with_s3_support": "true"}, + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_gcp_support_android_override", + define_values = {"with_gcp_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_android_override", + define_values = {"with_hdfs_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_android_override", + define_values = {"with_s3_support": "true"}, + values = {"crosstool_top": "//external:android/crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_gcp_support_ios_override", + define_values = {"with_gcp_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_hdfs_support_ios_override", + define_values = {"with_hdfs_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "with_s3_support_ios_override", + define_values = {"with_s3_support": "true"}, + values = {"crosstool_top": "//tools/osx/crosstool:crosstool"}, visibility = ["//visibility:public"], ) config_setting( name = "with_xla_support", - values = {"define": "with_xla_support=true"}, + define_values = {"with_xla_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_gdr_support", - values = {"define": "with_gdr_support=true"}, + define_values = {"with_gdr_support": "true"}, visibility = ["//visibility:public"], ) config_setting( name = "with_verbs_support", - values = {"define": "with_verbs_support=true"}, + define_values = {"with_verbs_support": "true"}, visibility = ["//visibility:public"], ) @@ -291,6 +367,7 @@ config_setting( package_group( name = "internal", packages = [ + "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", ], @@ -336,6 +413,7 @@ filegroup( "//tensorflow/compiler/tf2xla:all_files", "//tensorflow/compiler/tf2xla/cc:all_files", "//tensorflow/compiler/tf2xla/kernels:all_files", + "//tensorflow/compiler/tf2xla/lib:all_files", "//tensorflow/compiler/tf2xla/ops:all_files", "//tensorflow/compiler/xla:all_files", "//tensorflow/compiler/xla/client:all_files", @@ -408,11 +486,31 @@ filegroup( "//tensorflow/contrib/learn/python/learn/datasets:all_files", "//tensorflow/contrib/linalg:all_files", "//tensorflow/contrib/linear_optimizer:all_files", + "//tensorflow/contrib/lite:all_files", + "//tensorflow/contrib/lite/java:all_files", + "//tensorflow/contrib/lite/java/demo/app/src/main:all_files", + "//tensorflow/contrib/lite/java/demo/app/src/main/assets:all_files", + "//tensorflow/contrib/lite/java/src/main/native:all_files", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:all_files", + "//tensorflow/contrib/lite/kernels:all_files", + "//tensorflow/contrib/lite/kernels/internal:all_files", + "//tensorflow/contrib/lite/models/smartreply:all_files", + "//tensorflow/contrib/lite/nnapi:all_files", + "//tensorflow/contrib/lite/python:all_files", + "//tensorflow/contrib/lite/schema:all_files", + "//tensorflow/contrib/lite/testing:all_files", + "//tensorflow/contrib/lite/toco:all_files", + "//tensorflow/contrib/lite/toco/graph_transformations/tests:all_files", + "//tensorflow/contrib/lite/toco/python:all_files", + "//tensorflow/contrib/lite/toco/tensorflow_graph_matching:all_files", + "//tensorflow/contrib/lite/toco/tflite:all_files", + "//tensorflow/contrib/lite/tools:all_files", "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", "//tensorflow/contrib/makefile:all_files", "//tensorflow/contrib/meta_graph_transform:all_files", "//tensorflow/contrib/metrics:all_files", + "//tensorflow/contrib/model_pruning:all_files", "//tensorflow/contrib/mpi_collectives:all_files", "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nearest_neighbor:all_files", @@ -456,6 +554,7 @@ filegroup( "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:all_files", "//tensorflow/contrib/tpu:all_files", "//tensorflow/contrib/tpu/profiler:all_files", + "//tensorflow/contrib/tpu/proto:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", @@ -503,6 +602,7 @@ filegroup( "//tensorflow/java/src/main/native:all_files", "//tensorflow/python:all_files", "//tensorflow/python/data:all_files", + "//tensorflow/python/data/kernel_tests:all_files", "//tensorflow/python/data/ops:all_files", "//tensorflow/python/data/util:all_files", "//tensorflow/python/debug:all_files", @@ -539,6 +639,7 @@ filegroup( "//tensorflow/tools/test:all_files", "//tensorflow/user_ops:all_files", "//third_party/hadoop:all_files", + "//third_party/mpi:all_files", "//third_party/sycl:all_files", "//third_party/sycl/sycl:all_files", ], @@ -669,3 +770,10 @@ tf_cc_shared_object( "//tensorflow/core:tensorflow", ], ) + +exports_files( + [ + "tf_version_script.lds", + "tf_exported_symbols.lds", + ], +) diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 6dd1b999102d0135720b6ab3a43cbe61255acbc1..8a85eba5fc439af59144d3e8b869bf16b9462456 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -383,12 +383,11 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, // be less than the total node count. Status ValidateNoCycles(const Graph& g) { // TODO(nolivia): check this on a subset of the graph instead of all of it. - int total_num_nodes = g.num_node_ids(); // A node is ready when all of its inputs have been visited. std::vector ready; - std::vector pending_count(total_num_nodes, 0); + std::vector pending_count(g.num_node_ids(), 0); - for (int i = 0; i < total_num_nodes; ++i) { + for (int i = 0; i < g.num_node_ids(); ++i) { const Node* n = g.FindNodeId(i); if (n == nullptr) continue; pending_count[i] = n->in_edges().size(); @@ -421,7 +420,7 @@ Status ValidateNoCycles(const Graph& g) { } } - if (processed < total_num_nodes) { + if (processed < g.num_nodes()) { std::vector nodes_in_cycle; for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; ++i) { @@ -430,7 +429,7 @@ Status ValidateNoCycles(const Graph& g) { } } return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", total_num_nodes - processed, + "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); } return Status::OK(); @@ -580,6 +579,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (string #", i, " of ", srcarray.size(), "): ", status->status.error_message()); + delete[] base; return nullptr; } dst += consumed; @@ -589,6 +589,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (decoded ", (dst - base), " bytes, but the tensor is encoded in ", size, " bytes"); + delete[] base; return nullptr; } @@ -625,6 +626,23 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, return Status::OK(); } +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + // If any session has already run this node_id, mark this session as + // unrunnable. + for (auto it : graph->sessions) { + if (it.first->last_num_graph_nodes > op.node.id()) { + it.second = FailedPrecondition( + "Operation '", op.node.DebugString(), "' was changed by ", + mutation_type, + " after it was run by a session. Nodes can be mutated " + "only before they are executed by a session. Either don't modify " + "nodes after running them or create a new session."); + } + } +} + // Helpers for loading a TensorFlow plugin (a .so file). Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -890,8 +908,8 @@ const tensorflow::AttrValue* GetAttrValue(TF_Operation* oper, TF_Status* status) { const tensorflow::AttrValue* attr = oper->node.attrs().Find(attr_name); if (attr == nullptr) { - status->status = - InvalidArgument("Operation has no attr named '", attr_name, "'."); + status->status = InvalidArgument("Operation '", oper->node.name(), + "' has no attr named '", attr_name, "'."); } return attr; } @@ -939,13 +957,17 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, return; } - std::vector dim_vec; - dim_vec.reserve(num_dims); - for (int i = 0; i < num_dims; ++i) { - dim_vec.push_back(ic->MakeDim(dims[i])); + tensorflow::shape_inference::ShapeHandle new_shape; + if (num_dims != -1) { + std::vector dim_vec; + dim_vec.reserve(num_dims); + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + new_shape = ic->MakeShape(dim_vec); + } else { + new_shape = ic->UnknownShape(); } - - tensorflow::shape_inference::ShapeHandle new_shape = ic->MakeShape(dim_vec); status->status = graph->refiner.SetShape(node, output.index, new_shape); } @@ -1741,7 +1763,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, TF_Graph::TF_Graph() : graph(tensorflow::OpRegistry::Global()), refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), delete_requested(false), parent(nullptr), parent_inputs(nullptr) {} @@ -1751,7 +1772,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { g->mu.lock(); g->delete_requested = true; - const bool del = g->num_sessions == 0; + const bool del = g->sessions.empty(); g->mu.unlock(); if (del) delete g; } @@ -1831,6 +1852,16 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, opts->opts.prefix = prefix; } +void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_names) { + opts->opts.uniquify_names = uniquify_names; +} + +void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_prefix) { + opts->opts.uniquify_prefix = uniquify_prefix; +} + void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { @@ -2321,11 +2352,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, Session* session; status->status = NewSession(opt->options, &session); if (status->status.ok()) { + TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->num_sessions += 1; + graph->sessions[new_session] = Status::OK(); } - return new TF_Session(session, graph); + return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; @@ -2389,7 +2421,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->num_sessions += 1; + graph->sessions[session] = Status::OK(); session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2404,8 +2436,8 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) { TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); - graph->num_sessions -= 1; - const bool del = graph->delete_requested && graph->num_sessions == 0; + graph->sessions.erase(s); + const bool del = graph->delete_requested && graph->sessions.empty(); graph->mu.unlock(); if (del) delete graph; } @@ -2421,6 +2453,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { mutex_lock session_lock(session->mu); session->graph->mu.lock(); const Graph& graph = session->graph->graph; + + status->status = session->graph->sessions[session]; + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { status->status = tensorflow::ValidateNoCycles(session->graph->graph); diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index bb569d67fcbcec29e9494236abd79b3e40db91cd..df7fe222b130d2fd58915be112ff08a29d27639a 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -889,6 +889,20 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); +// Set whether to uniquify imported operation names. If true, imported operation +// names will be modified if their name already exists in the graph. If false, +// conflicting names will be treated as an error. Note that this option has no +// effect if a prefix is set, since the prefix will guarantee all names are +// unique. Defaults to false. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); + +// If true, the specified prefix will be modified if it already exists as an +// operation name or prefix in the graph. If false, a conflicting prefix will be +// treated as an error. This option has no effect if no prefix is specified. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); + // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index dcb818b88b6fca460852beb6e948d2eb6964f663..d60d1de315ed37a327bd036ddb914a3c32413f65 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -68,7 +68,7 @@ class NodeNameMapping { // This is a superset of values in name_mapping_. std::unordered_set used_names_; // Mapping from original node name from the graph to the normalized - // and uniqified version of it. + // and uniquified version of it. std::unordered_map name_mapping_; }; @@ -226,12 +226,17 @@ Status FillFunctionBody( } node_def->add_input(strings::StrCat("^", normalized)); } + + // A function is stateful if any of its nodes are stateful. + if (node->op_def().is_stateful()) { + fdef->mutable_signature()->set_is_stateful(true); + } } return Status::OK(); } // Graph to FunctionDef conversion. This code is closely modeled on the Python -// code in third_party/tensorflow/python/framework/function.py. +// code in tensorflow/python/framework/function.py. Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, bool append_hash_to_fn_name, const std::vector& body_nodes, diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index d5580b658992413ae6f9cb79ef88751ee28ce465..4ffc9d69312eae6c683b5701ceb44c13c7e61c5e 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -1482,6 +1482,51 @@ TEST_F(CApiFunctionTest, GetOpDef) { EXPECT_EQ(op_def.name(), func_name_); EXPECT_EQ(op_def.input_arg_size(), 1); EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_FALSE(op_def.is_stateful()); + + TF_DeleteBuffer(buffer); +} + +void DefineStatefulFunction(const char* name, TF_Function** func) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Tensor* tensor_shape = Int32Tensor({37, 1}); + TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape"); + TF_Operation* random = + RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get()); + + TF_Output inputs[] = {}; + TF_Output outputs[] = {{random, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/0, -1, + /*opers=*/nullptr, 0, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, "", s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); + TF_DeleteTensor(tensor_shape); +} + +TEST_F(CApiFunctionTest, StatefulOpDef) { + DefineStatefulFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 0); + EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_TRUE(op_def.is_stateful()); TF_DeleteBuffer(buffer); } diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index bb04e01beec931a8ea66d0855eec9625d3a6a5ab..aac333d9e29e60148a271fba95fadde708c7c370 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -81,12 +81,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // TF_Graph may only / must be deleted when - // num_sessions == 0 && delete_requested == true - - // num_sessions incremented by TF_NewSession, and decremented by + // The keys of this map are all the active sessions using this graph. + // Each value is the current "runnability" status of the corresponding + // session. Under normal conditions all statuses are Status::OK(), but + // if some operation is mutated after it was run by a session (this + // is detected in RecordMutation function), that session is no longer + // safe to run. Its status will contain the error that will be returned + // to the user, should she try running this session. + // + // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. - int num_sessions GUARDED_BY(mu); + // TF_Graph may only / must be deleted when + // sessions.size() == 0 && delete_requested == true + tensorflow::gtl::FlatMap sessions + GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph // Used to link graphs contained in TF_WhileParams to the parent graph that @@ -167,6 +175,9 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 05881e619ba232de99e78f315cfa8ab9294e5137..6ec1db8ccfdb713f330b708e604bd4b502ff7202 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -287,6 +287,13 @@ TEST(CAPI, SetShape) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); EXPECT_EQ(-1, num_dims); + // Set the shape to be unknown, expect no change. + TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + num_dims = TF_GraphGetTensorNumDims(graph, feed_out_0, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(-1, num_dims); + // Set the shape to be 2 x Unknown int64_t dims[] = {2, -1}; TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); @@ -315,7 +322,17 @@ TEST(CAPI, SetShape) { EXPECT_EQ(dims[0], returned_dims[0]); EXPECT_EQ(dims[1], returned_dims[1]); - // Try to set 'unknown' on the shape and see that + // Try to set 'unknown' with unknown rank on the shape and see that + // it doesn't change. + TF_GraphSetTensorShape(graph, feed_out_0, /*dims=*/nullptr, -1, s); + EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + EXPECT_EQ(2, num_dims); + EXPECT_EQ(2, returned_dims[0]); + EXPECT_EQ(3, returned_dims[1]); + + // Try to set 'unknown' with same rank on the shape and see that // it doesn't change. dims[0] = -1; dims[1] = -1; @@ -383,7 +400,7 @@ TEST(CAPI, Graph) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s)); - EXPECT_EQ(string("Operation has no attr named 'missing'."), + EXPECT_EQ(string("Operation 'feed' has no attr named 'missing'."), string(TF_Message(s))); // Make a constant oper with the scalar "3". @@ -1054,7 +1071,7 @@ class CApiColocationTest : public ::testing::Test { TF_OperationGetAttrMetadata(op, tensorflow::kColocationAttrName, s_); if (expected.empty()) { ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); - EXPECT_EQ(std::string("Operation has no attr named '_class'."), + EXPECT_EQ(std::string("Operation 'add' has no attr named '_class'."), std::string(TF_Message(s_))); return; } diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index c291a2e440a8515e968b0ce0395b289080f04e8b..37439ff0beac5a5220460465e954b6c093ee1ba9 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -193,6 +193,15 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, return TF_FinishOperation(desc, s); } +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s) { + TF_OperationDescription* desc = + TF_NewOperation(graph, "RandomUniform", "random_uniform"); + TF_AddInput(desc, {shape, 0}); + TF_SetAttrType(desc, "dtype", dtype); + return TF_FinishOperation(desc, s); +} + void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op) { TF_Operation* zero = ScalarConst( diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index d54733749248fa32c39d88bb0281d329dd50c7bd..96a93afef3e22d352fdbe911c3a5b01c867c6033 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -74,6 +74,9 @@ TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s); + // Split `input` along the first dimention into 3 tensors TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name = "split3"); diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index c77896b80b478cd34d3502e1061a7e76204ba021..d533758e360bc44a6f52f57eaae5b222e0482860 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -39,6 +39,7 @@ tf_cuda_library( tf_cuda_library( name = "c_api_internal", hdrs = ["c_api_internal.h"], + visibility = ["//tensorflow:internal"], deps = [ ":c_api", ":runtime", @@ -105,7 +106,6 @@ tf_cc_test( cc_library( name = "tape", - srcs = ["tape.cc"], hdrs = ["tape.h"], visibility = ["//tensorflow:internal"], deps = [ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 8359de62b7ff690fec9f6a0e3280f947c62f8b6e..706c89536db019c7f7389af576815746b2425520 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -571,6 +571,12 @@ void TFE_ContextAddFunctionDef(TFE_Context* ctx, status->status = ctx->func_lib_def.AddFunctionDef(function_def); } +void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, + TF_Status* status) { + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(function->fdef); +} + } // extern "C" TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 865580c5f3a823d9cf49fe460bd007e3b3b88767..ca105962df0d6655946304159937621022e7fcba 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -200,6 +200,13 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, const char* serialized_function_def, size_t size, TF_Status* status); +// Adds a function (created from TF_GraphToFunction or +// TF_FunctionImportFunctionDef) to the context, allowing it to be executed with +// TFE_Execute by creating an op with the same name as the function. +TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, + TF_Function* function, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 4af91b8853d0e85570bad136752a9d0a04b87da5..3fe0b7efa11bc619ed98bf9a1634ade5b6ed0a7c 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -295,6 +295,67 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } +TEST(CAPI, Function) { + // First create a simple identity function. + TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + TFE_DeleteContext(ctx, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} + string MatMulFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc deleted file mode 100644 index 464612a81ebda428f5582b6927f3a3b00a5aa6f5..0000000000000000000000000000000000000000 --- a/tensorflow/c/eager/tape.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/c/eager/tape.h" - -namespace tensorflow { -namespace eager { - -bool GradientTape::ShouldRecord(gtl::ArraySlice tensor_ids) { - for (int64 i : tensor_ids) { - if (tensor_tape_.find(i) != tensor_tape_.end()) { - return true; - } - } - return false; -} - -void GradientTape::Watch(int64 tensor_id) { - tensor_tape_.emplace(tensor_id, -1); -} - -void GradientTape::RecordOperation( - const string& op_type, gtl::ArraySlice output_tensors, - gtl::ArraySlice input_tensor_id, void* backward_function, - const std::function& backward_function_deleter) { - if (!ShouldRecord(input_tensor_id)) { - backward_function_deleter(); - return; - } - std::vector ids; - ids.reserve(input_tensor_id.size()); - for (int64 i : input_tensor_id) { - tensor_usage_[i]++; - ids.push_back(i); - } - const int64 op_id = next_op_id_++; - std::vector tensors; - tensors.reserve(output_tensors.size()); - for (const TapeTensor& o : output_tensors) { - // Note: the tensor can have already been watched and hence be in the tape, - // so we cannot check that we're inserting it here. - tensor_tape_[o.id] = op_id; - tensor_usage_[o.id] = 1; - tensors.push_back(o); - } - op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function, - backward_function_deleter}; -} - -void GradientTape::DeleteTrace(int64 tensor_id) { - auto it = tensor_usage_.find(tensor_id); - if (it == tensor_usage_.end()) { - return; - } - it->second--; - if (it->second != 0) { - return; - } - tensor_usage_.erase(it); - auto tensor_op_it = tensor_tape_.find(tensor_id); - if (tensor_op_it == tensor_tape_.end()) { - return; - } - const int64 op_id = tensor_op_it->second; - if (op_id == -1) { - // Do not delete watched tensors. - return; - } - tensor_tape_.erase(tensor_op_it); - auto op_it = op_tape_.find(op_id); - CHECK(op_it != op_tape_.end()); - for (const auto& output : op_it->second.output_tensor_info) { - if (tensor_usage_.find(output.id) != tensor_usage_.end()) { - // Found a usage for an output, so cannot delete the op. - return; - } - } - for (int64 id : op_it->second.input_tensor_id) { - DeleteTrace(id); - } - op_it->second.backward_function_deleter(); - op_tape_.erase(op_it); -} - -std::pair GradientTape::Export() { - return {std::move(tensor_tape_), std::move(op_tape_)}; -} - -} // namespace eager -} // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index df51f300eb61d54cb1e06d5a58a9b10e834f73c4..20ed037c52f34bc7a8aa39243c0b85e58fee1d46 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -19,6 +19,7 @@ limitations under the License. // maintains the data structures required to do so. #include +#include #include #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -36,13 +37,14 @@ struct TapeTensor { }; // Represents an entry in the tape. +template struct OpTapeEntry { string op_type; std::vector output_tensor_info; std::vector input_tensor_id; // TODO(apassos) consider narrowing down this interface. - void* backward_function; + BackwardFunction* backward_function; // Should be called before deleting the backward function. TODO(apassos) use // unique_ptrs to ensure this happens. @@ -55,13 +57,78 @@ struct OpTapeEntry { using TensorTape = std::unordered_map; // Map from operation-id to tape entry. -using OpTape = std::unordered_map; +template +using OpTape = std::unordered_map>; + +// Operations the tape needs to perform on tensors to do backpropagation. Named +// "vspace" because a subset of these are related to a vector space, such as +// adding gradients, getting zeroes, etc. Currently cannot be implemented +// without using tensorflow python code, hence left unspecified here. +// +// Gradient is the type returned by gradient functions. In Python TF it's either +// Tensor or IndexedSlices or None, which here we map to nullptr. Gradients need +// to allow their size to be computed and they need to be passable to a backward +// function and deleted (as the backprop code creates lots of gradients the user +// is not interested in). +// +// BackwardFunction needs to be a closure which stores intermediate activations +// from the forward computation and calls a vector-jacobian product function +// (also known as adjoint function) to compute, given downstream gradients, +// upstream gradients. +// +// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle +// specialization, which is blocked by quite a few things needing to loop back +// into python now. +template +class VSpace { + public: + virtual ~VSpace() {} + + // Returns the number of elements in the gradient tensor. + virtual int64 NumElements(Gradient* tensor) const = 0; + + // Consumes references to the tensors in the gradient_tensors list and returns + // a tensor with the result. + virtual Gradient* AggregateGradients( + gtl::ArraySlice gradient_tensors) const = 0; + + // Returns a tensor of the right shape and dtype filled with zeros. + virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0; + + // Returns a Tensor which is filled with ones and like the input. + virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0; + + // Calls the passed-in backward function. + virtual Status CallBackwardFunction( + BackwardFunction* backward_function, + gtl::ArraySlice output_gradients, + std::vector* result) const = 0; + + // Deletes the input tensor. + virtual void DeleteGradient(Gradient* gradient) const = 0; + + // Lets this VSpace know that it can release resources held by the + // `backward_function`, It will not be called again. + // `backward_function` must not be null. + virtual void ReleaseBackwardFunction( + BackwardFunction* backward_function) const = 0; +}; // Traces the execution of operations, doing eager garbage collection, and // exporting a full trace so other code can do backpropagation. Not thread-safe. +template class GradientTape { public: - GradientTape() {} + // If `persistent` is true, GradientTape will not eagerly delete backward + // functions (and hence the tensors they keep alive). Instead, everything + // is deleted in ~GradientTape. Persistent GradientTapes are useful when + // users want to compute multiple gradients over the same tape. + GradientTape(bool persistent) : persistent_(persistent) {} + ~GradientTape() { + for (const auto& pair : op_tape_) { + pair.second.backward_function_deleter(); + } + } bool ShouldRecord(gtl::ArraySlice tensor_ids); @@ -70,26 +137,486 @@ class GradientTape { void RecordOperation(const string& op_type, gtl::ArraySlice output_tensors, gtl::ArraySlice input_tensor_id, - void* backward_function, + BackwardFunction* backward_function, const std::function& backward_function_deleter); void DeleteTrace(int64 tensor_id); - // Note: it is only valid to call Export once per tape, and after calling - // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch, - // Record, and Delete have undefined behavior). - std::pair Export(); + // Consumes the internal state of the tape (so cannot be called more than + // once) and produces the gradient of the target tensors with respect to the + // source tensors. The output gradients are used if not empty and not + // null. The result is populated with one tensor per target element. + Status ComputeGradient(const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_id, + gtl::ArraySlice output_gradients, + std::vector* result); private: TensorTape tensor_tape_; - OpTape op_tape_; + OpTape op_tape_; int64 next_op_id_{0}; // Map from tensor id to number of remaining usages (i.e. how many entries in // the tape refer to it); to aid in tape garbage collection. std::unordered_map tensor_usage_; + + // If false, all activations are deleted in the first call to ComputeGradient. + // Else, only when this is destructed. + bool persistent_; +}; + +// Template instantiations here + +template +bool GradientTape::ShouldRecord( + gtl::ArraySlice tensor_ids) { + for (int64 i : tensor_ids) { + if (tensor_tape_.find(i) != tensor_tape_.end()) { + return true; + } + } + return false; +} + +template +void GradientTape::Watch(int64 tensor_id) { + tensor_tape_.emplace(tensor_id, -1); +} + +template +void GradientTape::RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, BackwardFunction* backward_function, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id)) { + backward_function_deleter(); + return; + } + std::vector ids; + ids.reserve(input_tensor_id.size()); + for (int64 i : input_tensor_id) { + tensor_usage_[i]++; + ids.push_back(i); + } + const int64 op_id = next_op_id_++; + std::vector tensors; + tensors.reserve(output_tensors.size()); + for (const TapeTensor& o : output_tensors) { + // Note: the tensor can have already been watched and hence be in the tape, + // so we cannot check that we're inserting it here. + tensor_tape_[o.id] = op_id; + tensor_usage_[o.id] = 1; + tensors.push_back(o); + } + op_tape_[op_id] = OpTapeEntry{ + op_type, tensors, ids, backward_function, backward_function_deleter}; +} + +template +void GradientTape::DeleteTrace(int64 tensor_id) { + auto it = tensor_usage_.find(tensor_id); + if (it == tensor_usage_.end()) { + return; + } + it->second--; + if (it->second != 0) { + return; + } + tensor_usage_.erase(it); + auto tensor_op_it = tensor_tape_.find(tensor_id); + if (tensor_op_it == tensor_tape_.end()) { + return; + } + const int64 op_id = tensor_op_it->second; + if (op_id == -1) { + // Do not delete watched tensors. + return; + } + tensor_tape_.erase(tensor_op_it); + auto op_it = op_tape_.find(op_id); + CHECK(op_it != op_tape_.end()); + for (const auto& output : op_it->second.output_tensor_info) { + if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + // Found a usage for an output, so cannot delete the op. + return; + } + } + for (int64 id : op_it->second.input_tensor_id) { + DeleteTrace(id); + } + op_it->second.backward_function_deleter(); + op_tape_.erase(op_it); +} + +// Terminology: +// +// - op: a possibly composite operation, which has an entry in the tape +// - target: dy in dx/dy +// - source: dx in dx/dy +// - tensor: one of the many inputs or outputs of an operation +// +// Below here we do the gradient algorithm. It works as follows: +// +// First we filter the tape to just the subset of operations we want to +// differentiate. In the process of doing so we count how many times each Tensor +// is used as an input to an op (so we know when we're done computing gradients +// for that Tensor). We also count, for each tape entry, how many of its output +// Tensors need gradients to be computed (Tensors which are not used do not need +// any gradients to be computed). +// +// Finally, we start a backprop stack with a set of tape entries for which we +// have all gradients available. This set usually is a subset of the set of +// targets (not all since targets which have outputs in the tape will not have +// gradients available initially). +// +// Then we repeatedly pop an entry from the stack, run its backprop, and update +// the gradients of its inputs. Once we have computed all gradients for a single +// input we can mark this input as done, and this can trigger adding an entry to +// the stack if all outputs of that entry are now done. +// +// When the stack is empty we have gradients for all tensors we're interested +// in. + +namespace { + +template +struct BackpropInitialState { + OpTape op_tape; + + // Map from tensor ID to how many references still exist for this tensor in + // the tape. + std::unordered_map tensor_usage_counts; + + // Maps from op ID to how many output tensors of this op still need to have + // their gradients computed. + std::unordered_map op_missing_tensor; }; +// If `persistent_tape` is true, op_tape is not changed and none of the +// backwards functions are deleted. +// If `persistent_tape` is false, op_tape is cleared and backwards functions +// not needed for gradient computation are deleted. Backwards functions that +// are needed, are copied and returned in BackpropInitialState. +template +BackpropInitialState PrepareBackprop( + gtl::ArraySlice target, const TensorTape& tensor_tape, + OpTape* op_tape, + const std::unordered_set& sources_set, bool persistent_tape) { + std::vector tensor_stack; + tensor_stack.reserve(target.size()); + for (auto t : target) { + tensor_stack.push_back(t); + } + BackpropInitialState result; + while (!tensor_stack.empty()) { + int64 tensor_id = tensor_stack.back(); + tensor_stack.pop_back(); + auto op_id_it = tensor_tape.find(tensor_id); + if (op_id_it == tensor_tape.end()) { + continue; + } + int64 op_id = op_id_it->second; + auto op_it = op_tape->find(op_id); + auto result_op_it = result.op_tape.find(op_id); + if (op_id == -1 || op_it == op_tape->end() || + result_op_it != result.op_tape.end()) { + continue; + } + CHECK(result.op_tape.emplace(op_id, op_it->second).second); + for (auto it : op_it->second.input_tensor_id) { + auto count_it = result.tensor_usage_counts.find(it); + if (count_it != result.tensor_usage_counts.end()) { + count_it->second++; + } else { + result.tensor_usage_counts[it] = 1; + if (sources_set.find(it) == sources_set.end() && + tensor_tape.find(it) != tensor_tape.end()) { + tensor_stack.push_back(it); + } + } + } + if (!persistent_tape) { + op_tape->erase(op_it); + } + } + for (auto& pair : result.tensor_usage_counts) { + auto it = tensor_tape.find(pair.first); + if (it != tensor_tape.end() && it->second != -1) { + result.op_missing_tensor[it->second] += 1; + } + } + if (!persistent_tape) { + // Call destructors for all unneeded gradient functions and + // clear the op_tape. We can clear the tape because ownership of + // backward functions that will be used for gradient computation + // has been transfered to `result`. + for (const auto& op_pair : *op_tape) { + op_pair.second.backward_function_deleter(); + } + op_tape->clear(); + } + return result; +} + +template +std::vector InitialStack( + const OpTape& op_tape, + const std::unordered_map& op_missing_tensor) { + std::vector result; + for (auto& op_entry : op_tape) { + if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { + result.push_back(op_entry.first); + } + } + return result; +} + +template +Status InitialGradients( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, + const OpTape& op_tape, + const std::unordered_map& tensor_usage_counts, + std::unordered_map>* result) { + for (int i = 0; i < target_tensor_ids.size(); ++i) { + const int64 id = target_tensor_ids[i]; + if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { + if (!output_gradients.empty() && output_gradients[i] != nullptr) { + // TODO(apassos) figure out how to print debugging information here. + return errors::InvalidArgument( + "A gradient was provided for a tensor which is used as part of the " + "computation."); + } + } else { + if (output_gradients.empty() || output_gradients[i] == nullptr) { + auto tensor_it = tensor_tape.find(id); + if (tensor_it != tensor_tape.end() && tensor_it->second != -1) { + auto op_it = op_tape.find(tensor_it->second); + if (op_it == op_tape.end()) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "failed to find operation producing a tensor"); + } + bool found = false; + for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) { + if (op_it->second.output_tensor_info[j].id == id) { + found = true; + (*result)[id].push_back( + vspace.Ones(op_it->second.output_tensor_info[j].shape, + op_it->second.output_tensor_info[j].dtype)); + break; + } + } + if (!found) { + return errors::Internal( + "Internal state of the gradient tape is invalid: " + "none of operations outputs match expected tensor"); + } + } else { + // No record of the target tensor found on the tape, so no gradient + // needs to be computed from it. Do nothing. + } + } else { + (*result)[id].push_back(output_gradients[i]); + } + } + } + return Status::OK(); +} + +} // namespace + +// If over kMinAggregateCount gradients are accumulated and the total +// memory consumption is over kMinAggregateBytes, do an early aggregation +// so as to release the gradient tensor to save memory. +constexpr int kMinAggregateCount = 4; +constexpr int kMinAggregateBytes = 128 * 1024 * 1024; + +template +Status GradientTape::ComputeGradient( + const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice source_tensor_ids, + gtl::ArraySlice output_gradients, + std::vector* result) { + std::unordered_set sources_set(source_tensor_ids.begin(), + source_tensor_ids.end()); + BackpropInitialState state = PrepareBackprop( + target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); + std::vector op_stack = + InitialStack(state.op_tape, state.op_missing_tensor); + std::unordered_map> gradients; + Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, + tensor_tape_, state.op_tape, + state.tensor_usage_counts, &gradients); + auto cleanup = [this, &state]() { + if (!persistent_) { + // Release all backprop functions + for (const auto& pair : state.op_tape) { + pair.second.backward_function_deleter(); + } + } + }; + if (!s.ok()) { + cleanup(); + return s; + } + std::unordered_map gradients_size; + // TODO(apassos) multiple threads could be dequeuing from op_stack at the same + // time, for better CPU backprop performance. + VLOG(1) << "Initial stack:"; + if (VLOG_IS_ON(1)) { + for (auto t : op_stack) { + VLOG(1) << " " << t; + } + } + std::unordered_map> + functions_accept_none_for_indices({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); + while (!op_stack.empty()) { + const int64 op = op_stack.back(); + VLOG(1) << "Popped " << op; + op_stack.pop_back(); + auto op_it = state.op_tape.find(op); + if (op_it == state.op_tape.end()) { + // It is possible for ops to end up on the stack if they are unrelated to + // the target; we should just skip them. + continue; + } + auto trace = std::move(op_it->second); + state.op_tape.erase(op_it); + std::vector out_gradients; + out_gradients.reserve(trace.output_tensor_info.size()); + bool any_gradient_nonzero = false; + for (int i = 0; i < trace.output_tensor_info.size(); ++i) { + const int64 id = trace.output_tensor_info[i].id; + auto grad_it = gradients.find(id); + if (grad_it == gradients.end()) { + auto func_name_it = + functions_accept_none_for_indices.find(trace.op_type); + if (func_name_it != functions_accept_none_for_indices.end() && + func_name_it->second.find(i) != func_name_it->second.end()) { + out_gradients.push_back(nullptr); + } else { + out_gradients.push_back( + vspace.Zeros(trace.output_tensor_info[i].shape, + trace.output_tensor_info[i].dtype)); + } + } else { + any_gradient_nonzero = true; + out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); + if (sources_set.find(grad_it->first) == sources_set.end()) { + gradients.erase(grad_it); + } + } + } + std::vector in_gradients; + if (any_gradient_nonzero) { + Status s = vspace.CallBackwardFunction(trace.backward_function, + out_gradients, &in_gradients); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + if (!s.ok()) { + cleanup(); + return s; + } + } else { + in_gradients.resize(trace.input_tensor_id.size()); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + } + VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " + << trace.input_tensor_id.size() << " sources"; + for (int i = 0; i < in_gradients.size(); ++i) { + const int64 id = trace.input_tensor_id[i]; + if (in_gradients[i] != nullptr) { + auto& unaggregated_grads = gradients[id]; + unaggregated_grads.push_back(in_gradients[i]); + if (unaggregated_grads.size() > kMinAggregateCount) { + auto size_it = gradients_size.find(id); + int64 size; + if (size_it == gradients_size.end()) { + size = vspace.NumElements(unaggregated_grads[0]); + gradients_size.emplace(id, size); + } else { + size = size_it->second; + } + if (unaggregated_grads.size() * size * 4 > kMinAggregateBytes) { + Gradient* grad = vspace.AggregateGradients(unaggregated_grads); + unaggregated_grads.clear(); + unaggregated_grads.push_back(grad); + } + } + } + auto usage_count_it = state.tensor_usage_counts.find(id); + if (usage_count_it == state.tensor_usage_counts.end()) { + VLOG(1) << "Tensor " << id << " not used"; + continue; + } + usage_count_it->second--; + if (usage_count_it->second > 0) { + VLOG(1) << "Tensor " << id << " usage count " << usage_count_it->second; + continue; + } + auto tape_it = tensor_tape_.find(id); + if (tape_it == tensor_tape_.end()) { + VLOG(1) << "Tensor " << id + << " has no associated op. Deleting gradient"; + auto grad_it = gradients.find(id); + if (grad_it != gradients.end()) { + for (auto g : grad_it->second) { + vspace.DeleteGradient(g); + } + gradients.erase(grad_it); + } + continue; + } + const int64 op_id = tape_it->second; + if (op_id == -1) { + VLOG(1) << "Tensor " << id << " is source"; + continue; + } + auto missing_it = state.op_missing_tensor.find(op_id); + if (missing_it != state.op_missing_tensor.end()) { + missing_it->second--; + VLOG(1) << "Op " << op_id << " missing " << missing_it->second + << " output gradients"; + if (missing_it->second == 0) { + op_stack.push_back(op_id); + } + } + } + } + CHECK(state.op_tape.empty()); + result->reserve(source_tensor_ids.size()); + for (auto is : source_tensor_ids) { + auto grad_it = gradients.find(is); + if (grad_it == gradients.end()) { + result->push_back(nullptr); + } else { + if (grad_it->second.size() == 1) { + result->push_back(grad_it->second[0]); + } else { + result->push_back(vspace.AggregateGradients(grad_it->second)); + } + gradients.erase(grad_it); + } + } + VLOG(1) << "Final gradients size: " << gradients.size(); + for (auto grad_pair : gradients) { + for (const auto& g : grad_pair.second) { + vspace.DeleteGradient(g); + } + } + return Status::OK(); +} + } // namespace eager } // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 0fe85d5d2c60bcd0566f010b23820ec174b7830b..6e37cdb5f4beea53d4a2ded0705ae482d0bc2d68 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -22,18 +22,81 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { mutex_lock l(graph->mu); graph->graph.AddControlEdge(&input->node, &op->node); + RecordMutation(graph, *op, "adding control input"); +} + +void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Buffer* attr_value_proto, TF_Status* status) { + AttrValue attr_val; + if (!attr_val.ParseFromArray(attr_value_proto->data, + attr_value_proto->length)) { + status->status = + tensorflow::errors::InvalidArgument("Invalid AttrValue proto"); + return; + } + + mutex_lock l(graph->mu); + op->node.AddAttr(attr_name, attr_val); + RecordMutation(graph, *op, "setting attribute"); } void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); + RecordMutation(graph, *op, "setting device"); } void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status) { mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(&new_src.oper->node); + + if (ic->num_outputs() <= new_src.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Output index [", new_src.index, + "] is greater than the number of total outputs [", ic->num_outputs(), + "]."); + return; + } + tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index); + + tensorflow::shape_inference::InferenceContext* ic_dst = + graph->refiner.GetContext(&dst.oper->node); + if (ic_dst->num_inputs() <= dst.index) { + status->status = tensorflow::errors::OutOfRange( + "Cannot update edge. Input index [", dst.index, + "] is greater than the number of total inputs [", ic_dst->num_inputs(), + "]."); + return; + } + if (!ic_dst->MergeInput(dst.index, shape)) { + status->status = tensorflow::errors::InvalidArgument( + "Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape), + " and ", ic_dst->DebugString(ic_dst->input(dst.index)), "."); + return; + } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); + + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. + RecordMutation(graph, *dst.oper, "updating input tensor"); + } +} + +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { + mutex_lock l(graph->mu); + std::vector control_edges; + for (const Edge* edge : op->node.in_edges()) { + if (!edge->IsControlEdge()) continue; + control_edges.push_back(edge); + } + for (const Edge* edge : control_edges) { + graph->graph.RemoveControlEdge(edge); + } } } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index ab71a4170bb58df46a3d23585cf256eb656d38d2..b51ef2b53122802fef598a26bd6f1843976f11b0 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -25,11 +25,18 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); +// Changes an attr value in the node_def Protocol Buffer and sets a status upon +// completion. +void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Buffer* attr_value_proto, TF_Status* status); + void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index 80112f9b44b1d5fd65a7d47788b072dc47a2b29a..e354831d7d25af83c068a68a4f844056263a598c 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -421,6 +421,7 @@ tf_cc_test( tf_gen_op_wrappers_cc( name = "cc_ops", + api_def_srcs = ["//tensorflow/core:base_api_def"], op_lib_names = [ "array_ops", "audio_ops", @@ -525,6 +526,30 @@ cc_library_with_android_deps( "//tensorflow/core:android_tensorflow_lib", ], copts = tf_copts(), + data = [ + "//tensorflow/core:base_api_def", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", + "//tensorflow/core:op_gen_overrides_proto_cc", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "cc_op_gen_test", + srcs = [ + "framework/cc_op_gen.cc", + "framework/cc_op_gen.h", + "framework/cc_op_gen_test.cc", + ], + data = [ + "//tensorflow/cc:ops/op_gen_overrides.pbtxt", + ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -533,6 +558,8 @@ cc_library_with_android_deps( "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", ], ) diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index 38a17598b8e4161f96ab8134823de033d3284440..d889c518f9c38a9f070970b37a2ad4b1fc26671b 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include "tensorflow/cc/framework/cc_op_gen.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_gen_lib.h" -#include "tensorflow/core/framework/op_gen_overrides.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb_text.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { - namespace { const int kRightMargin = 79; @@ -297,7 +297,7 @@ string ToCamelCase(const string& str) { // argument to a function. std::pair AttrTypeName(StringPiece attr_type) { static const std::unordered_map, - StringPiece::Hasher> + StringPieceHasher> attr_type_map{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, @@ -325,29 +325,112 @@ std::pair AttrTypeName(StringPiece attr_type) { } bool IsCPPKeyword(StringPiece name) { - static const std::unordered_set + static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword kCPPReserved{ - "alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", - "atomic_commit", "atomic_noexcept", "auto", "bitand", "bitor", "bool", - "break", "case", "catch", "char", "char16_t", "char32_t", "class", - "compl", "concept", "const", "const_cast", "constexpr", "continue", - "decltype", "default", "delete", "do", "double", "dynamic_cast", - "else", "enum", "explicit", "export", "extern", "false", "final", - "float", "for", "friend", "goto", "if", "import", "inline", "int", - "long", "module", "mutable", "namespace", "new", "noexcept", "not", - "not_eq", "nullptr", "operator", "or", "or_eq", "override", "private", - "protected", "public", "register", "reinterpret_cast", "requires", - "return", "short", "signed", "sizeof", "static", "static_assert", - "static_cast", "struct", "switch", "synchronized", "template", "this", - "thread_local", "throw", "true", "try", "typedef", "typeid", - "typename", "union", "unsigned", "using", "virtual", "void", - "volatile", "wchar_t", "while", "xor", "xor_eq", + "alignas", + "alignof", + "and", + "and_eq", + "asm", + "atomic_cancel", + "atomic_commit", + "atomic_noexcept", + "auto", + "bitand", + "bitor", + "bool", + "break", + "case", + "catch", + "char", + "char16_t", + "char32_t", + "class", + "compl", + "concept", + "const", + "const_cast", + "constexpr", + "continue", + "decltype", + "default", + "delete", + "do", + "double", + "dynamic_cast", + "else", + "enum", + "explicit", + "export", + "extern", + "false", + "final", + "float", + "for", + "friend", + "goto", + "if", + "import", + "inline", + "int", + "long", + "module", + "mutable", + "namespace", + "new", + "noexcept", + "not", + "not_eq", + "nullptr", + "operator", + "or", + "or_eq", + "override", + "private", + "protected", + "public", + "register", + "reinterpret_cast", + "requires", + "return", + "short", + "signed", + "sizeof", + "static", + "static_assert", + "static_cast", + "struct", + "switch", + "synchronized", + "template", + "this", + "thread_local", + "throw", + "true", + "try", + "typedef", + "typeid", + "typename", + "union", + "unsigned", + "using", + "virtual", + "void", + "volatile", + "wchar_t", + "while", + "xor", + "xor_eq", // The following are not C++ keywords, but names of local variables // and parameters used in the op constructor. Treating them as // keywords, so that other parameter names don't conflict with these. - "builder", "node", "ret", "scope", "unique_name", + "builder", + "node", + "ret", + "scope", + "unique_name", }; return kCPPReserved.count(name) > 0; } @@ -385,10 +468,10 @@ bool ArgIsList(const OpDef::ArgDef& arg) { } bool HasOptionalAttrs( - const OpDef& op_def, + const ApiDef& api_def, const std::unordered_map& inferred_input_attrs) { - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < api_def.attr_size(); ++i) { + const auto& attr(api_def.attr(i)); if ((inferred_input_attrs.find(attr.name()) == inferred_input_attrs.end()) && attr.has_default_value()) { @@ -398,12 +481,21 @@ bool HasOptionalAttrs( return false; } +const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) { + for (int i = 0; i < api_def.in_arg_size(); ++i) { + if (api_def.in_arg(i).name() == name) { + return &api_def.in_arg(i); + } + } + return nullptr; +} + struct OpInfo { // graph_op_def: The OpDef used by the runtime, has the names that // must be used when calling NodeBuilder. // interface_op_def: The OpDef used in the interface in the generated // code, with possibly overridden names and defaults. - explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def, + explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases); string GetOpAttrStruct() const; string GetConstructorDecl(StringPiece op_name_prefix, @@ -423,74 +515,81 @@ struct OpInfo { string comment; const OpDef& graph_op_def; - const OpDef& op_def; + const ApiDef& api_def; const std::vector& aliases; + // Map from type attribute to corresponding original argument name. std::unordered_map inferred_input_attrs; }; -OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, - const std::vector& a) - : graph_op_def(g_op_def), op_def(i_op_def), aliases(a) { - op_name = op_def.name(); - InferOpAttributes(op_def, &inferred_input_attrs); - has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs); +OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, + const std::vector& aliases) + : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) { + op_name = api_def.endpoint(0).name(); + InferOpAttributes(graph_op_def, &inferred_input_attrs); + has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs); arg_types.push_back("const ::tensorflow::Scope&"); arg_names.push_back("scope"); - if (op_def.has_deprecation()) { - if (!op_def.summary().empty()) { - comment = strings::StrCat(op_def.summary(), "\n"); + if (graph_op_def.has_deprecation()) { + if (!api_def.summary().empty()) { + comment = strings::StrCat(api_def.summary(), "\n"); } strings::StrAppend(&comment, "DEPRECATED at GraphDef version ", - op_def.deprecation().version(), ":\n", - op_def.deprecation().explanation(), ".\n"); - } else if (op_def.summary().empty()) { + graph_op_def.deprecation().version(), ":\n", + graph_op_def.deprecation().explanation(), ".\n"); + } else if (api_def.summary().empty()) { comment = "TODO: add doc.\n"; } else { - comment = strings::StrCat(op_def.summary(), "\n"); + comment = strings::StrCat(api_def.summary(), "\n"); } - if (!op_def.description().empty()) { - strings::StrAppend(&comment, "\n", op_def.description(), "\n"); + if (!api_def.description().empty()) { + strings::StrAppend(&comment, "\n", api_def.description(), "\n"); } strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n"); // Process inputs - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); + for (int i = 0; i < api_def.arg_order_size(); ++i) { + const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def); + const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def); arg_types.push_back(strings::StrCat( "::tensorflow::", ArgIsList(arg) ? "InputList" : "Input")); - arg_names.push_back(AvoidCPPKeywords(arg.name())); + arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); // TODO(keveman): Include input type information. - StringPiece description = arg.description(); + StringPiece description = api_def_arg.description(); if (!description.empty()) { ConsumeEquals(&description); - strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ", - arg.description(), "\n"); + strings::StrAppend(&comment, "* ", + AvoidCPPKeywords(api_def_arg.rename_to()), ": ", + api_def_arg.description(), "\n"); } } // Process attrs string required_attrs_comment; string optional_attrs_comment; - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + // ApiDef attributes must be in the same order as in OpDef since + // we initialize ApiDef based on OpDef. + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); + CHECK_EQ(attr.name(), api_def_attr.name()); // Skip inferred arguments if (inferred_input_attrs.count(attr.name()) > 0) continue; const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - string attr_name = AvoidCPPKeywords(attr.name()); + string attr_name = AvoidCPPKeywords(api_def_attr.rename_to()); string attr_comment; - if (!attr.description().empty()) { + if (!api_def_attr.description().empty()) { // TODO(keveman): Word wrap and indent this, to handle multi-line // descriptions. strings::StrAppend(&attr_comment, "* ", attr_name, ": ", - attr.description(), "\n"); + api_def_attr.description(), "\n"); } - if (attr.has_default_value()) { + if (api_def_attr.has_default_value()) { strings::StrAppend(&optional_attrs_comment, attr_comment); } else { strings::StrAppend(&required_attrs_comment, attr_comment); @@ -508,44 +607,49 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def, } // Process outputs - for (int i = 0; i < op_def.output_arg_size(); ++i) { - const auto& arg = op_def.output_arg(i); + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { + // ApiDef arguments must be in the same order as in OpDef since + // we initialize ApiDef based on OpDef. + const auto& arg = graph_op_def.output_arg(i); + const auto& api_def_arg(api_def.out_arg(i)); + CHECK_EQ(arg.name(), api_def_arg.name()); + bool is_list = ArgIsList(arg); output_types.push_back( strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output")); - output_names.push_back(AvoidCPPKeywords(arg.name())); + output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); is_list_output.push_back(is_list); } strings::StrAppend(&comment, "\nReturns:\n"); - if (op_def.output_arg_size() == 0) { // No outputs. + if (graph_op_def.output_arg_size() == 0) { // No outputs. strings::StrAppend(&comment, "* the created `Operation`\n"); - } else if (op_def.output_arg_size() == 1) { // One output + } else if (graph_op_def.output_arg_size() == 1) { // One output if (is_list_output[0]) { strings::StrAppend(&comment, "* `OutputList`: "); } else { strings::StrAppend(&comment, "* `Output`: "); } - if (op_def.output_arg(0).description().empty()) { - strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(), + if (api_def.out_arg(0).description().empty()) { + strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(), " tensor.\n"); } else { // TODO(josh11b): Word wrap this. - strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n"); + strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n"); } } else { // Multiple outputs. - for (int i = 0; i < op_def.output_arg_size(); ++i) { + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { if (is_list_output[i]) { strings::StrAppend(&comment, "* `OutputList`"); } else { strings::StrAppend(&comment, "* `Output`"); } strings::StrAppend(&comment, " ", output_names[i]); - if (op_def.output_arg(i).description().empty()) { + if (api_def.out_arg(i).description().empty()) { strings::StrAppend(&comment, "\n"); } else { // TODO(josh11b): Word wrap this. - strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(), + strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(), "\n"); } } @@ -564,19 +668,20 @@ string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); // If attr will be inferred or it doesn't have a default value, don't // add it to the struct. if ((inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) || - !attr.has_default_value()) { + !api_def_attr.has_default_value()) { continue; } const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - const string camel_case_name = ToCamelCase(attr.name()); + const string camel_case_name = ToCamelCase(api_def_attr.rename_to()); const string suffix = (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : ""; const string attr_func_def = @@ -584,22 +689,25 @@ string OpInfo::GetOpAttrStruct() const { attr_type_name, use_const ? "&" : ""); string attr_comment; - if (!attr.description().empty()) { - strings::StrAppend(&attr_comment, attr.description(), "\n\n"); + if (!api_def_attr.description().empty()) { + strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n"); } strings::StrAppend(&attr_comment, "Defaults to ", - SummarizeAttrValue(attr.default_value()), "\n"); + SummarizeAttrValue(api_def_attr.default_value()), "\n"); attr_comment = MakeComment(attr_comment, " "); strings::StrAppend(&setters, attr_comment); strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n"); strings::StrAppend(&setters, " Attrs ret = *this;\n"); - strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n"); + strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(), + "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); strings::StrAppend( - &struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ", - PrintAttrValue(op_def.name(), attr.default_value()), ";\n"); + &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), + "_ = ", + PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), + ";\n"); } if (struct_fields.empty()) { @@ -676,17 +784,18 @@ void OpInfo::WriteClassDecl(WritableFile* h) const { // Add the static functions to set optional attrs if (has_optional_attrs) { strings::StrAppend(&class_decl, "\n"); - for (int i = 0; i < op_def.attr_size(); ++i) { - const auto& attr(op_def.attr(i)); + for (int i = 0; i < graph_op_def.attr_size(); ++i) { + const auto& attr(graph_op_def.attr(i)); + const auto& api_def_attr(api_def.attr(i)); if ((inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) || - !attr.has_default_value()) { + !api_def_attr.has_default_value()) { continue; } const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; - const string camel_case_name = ToCamelCase(attr.name()); + const string camel_case_name = ToCamelCase(api_def_attr.rename_to()); const string suffix = (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : ""; const string attr_func_def = strings::StrCat( @@ -726,11 +835,11 @@ void OpInfo::GetOutput(string* out) const { strings::StrCat("if (!", scope_str, ".ok()) return;"); // No outputs. - if (op_def.output_arg_size() == 0) { + if (graph_op_def.output_arg_size() == 0) { strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n"); return; } - if (op_def.output_arg_size() == 1) { + if (graph_op_def.output_arg_size() == 1) { // One output, no need for NameRangeMap if (is_list_output[0]) { strings::StrAppend(out, @@ -752,7 +861,7 @@ void OpInfo::GetOutput(string* out) const { ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); - for (int i = 0; i < op_def.output_arg_size(); ++i) { + for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { const string arg_range = strings::StrCat( "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]"); if (is_list_output[i]) { @@ -776,11 +885,13 @@ string OpInfo::GetConstructorBody() const { strings::StrAppend(&body, " ", return_on_error, "\n"); - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::", - ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", - scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n"); + for (int i = 0; i < graph_op_def.input_arg_size(); ++i) { + const auto& arg(graph_op_def.input_arg(i)); + const auto& api_def_arg(api_def.in_arg(i)); + strings::StrAppend( + &body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::", + ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ", + AvoidCPPKeywords(api_def_arg.rename_to()), ");\n"); strings::StrAppend(&body, " ", return_on_error, "\n"); } @@ -791,19 +902,21 @@ string OpInfo::GetConstructorBody() const { &body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"", graph_op_def.name(), "\")\n"); const string spaces = " "; - for (int i = 0; i < op_def.input_arg_size(); ++i) { - const auto& arg(op_def.input_arg(i)); - strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n"); + for (int i = 0; i < api_def.in_arg_size(); ++i) { + const auto& arg(api_def.in_arg(i)); + strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n"); } - for (int i = 0; i < op_def.attr_size(); ++i) { + for (int i = 0; i < api_def.attr_size(); ++i) { const auto& graph_attr(graph_op_def.attr(i)); - const auto& attr(op_def.attr(i)); - if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) { + const auto& api_def_attr(api_def.attr(i)); + if (inferred_input_attrs.find(api_def_attr.name()) != + inferred_input_attrs.end()) { continue; } - const string attr_name = attr.has_default_value() - ? strings::StrCat("attrs.", attr.name(), "_") - : AvoidCPPKeywords(attr.name()); + const string attr_name = + api_def_attr.has_default_value() + ? strings::StrCat("attrs.", api_def_attr.rename_to(), "_") + : AvoidCPPKeywords(api_def_attr.rename_to()); strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ", attr_name, ")\n"); } @@ -845,10 +958,10 @@ void OpInfo::WriteClassDef(WritableFile* cc) const { TF_CHECK_OK(cc->Append(class_def)); } -void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def, +void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases, WritableFile* h, WritableFile* cc) { - OpInfo op_info(graph_op_def, interface_op_def, aliases); + OpInfo op_info(graph_op_def, api_def, aliases); op_info.WriteClassDecl(h); op_info.WriteClassDef(cc); @@ -943,8 +1056,9 @@ string MakeInternal(const string& fname) { } // namespace -void WriteCCOps(const OpList& ops, const string& dot_h_fname, - const string& dot_cc_fname, const string& overrides_fnames) { +void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& dot_h_fname, const string& dot_cc_fname, + const string& overrides_fnames) { Env* env = Env::Default(); // Load the override map. @@ -984,24 +1098,23 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname, // code depends on it. if (graph_op_def.name() == "Const") continue; - // Incorporate overrides from override_map. - OpDef interface_op_def = graph_op_def; - const OpGenOverride* op_override = - override_map.ApplyOverride(&interface_op_def); + const auto* api_def = api_def_map.GetApiDef(graph_op_def.name()); + std::vector aliases; - if (op_override) { - if (op_override->skip()) continue; - aliases.assign(op_override->alias().begin(), op_override->alias().end()); - if (op_override->hide()) { - // Write hidden ops to _internal.h and _internal.cc. - WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(), - internal_cc.get()); - continue; - } + if (api_def->visibility() == ApiDef::SKIP) continue; + // First endpoint is canonical, the rest are aliases. + for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size(); + ++endpoint_i) { + aliases.push_back(api_def->endpoint(endpoint_i).name()); + } + if (api_def->visibility() == ApiDef::HIDDEN) { + // Write hidden ops to _internal.h and _internal.cc. + WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(), + internal_cc.get()); + continue; } - // This isn't a hidden op, write it to the main files. - WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get()); + WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get()); } FinishFiles(false, h.get(), cc.get(), op_header_guard); diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index fa5e004f0317d046d82bee005bdf9f17773a45f3..cea28990144b9371e8009ce13f912b44044f9aac 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -17,13 +17,15 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { /// Result is written to files dot_h and dot_cc. -void WriteCCOps(const OpList& ops, const string& dot_h_fname, - const string& dot_cc_fname, const string& overrides_fnames); +void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, + const string& dot_h_fname, const string& dot_cc_fname, + const string& overrides_fnames); } // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc index 3b80cf993eb9a5d5f4c41687577414e7216dd174..326d5668b8803ee39ffe24900c92e1db87b93601 100644 --- a/tensorflow/cc/framework/cc_op_gen_main.cc +++ b/tensorflow/cc/framework/cc_op_gen_main.cc @@ -16,7 +16,11 @@ limitations under the License. #include "tensorflow/cc/framework/cc_op_gen.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/types.h" @@ -24,10 +28,28 @@ namespace tensorflow { namespace { void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, - const std::string& overrides_fnames, bool include_internal) { + const std::string& overrides_fnames, bool include_internal, + const std::vector& api_def_dirs) { OpList ops; OpRegistry::Global()->Export(include_internal, &ops); - WriteCCOps(ops, dot_h, dot_cc, overrides_fnames); + ApiDefMap api_def_map(ops); + if (!api_def_dirs.empty()) { + Env* env = Env::Default(); + // Only load files that correspond to "ops". + for (const auto& op : ops.op()) { + for (const auto& api_def_dir : api_def_dirs) { + const std::string api_def_file_pattern = + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); + if (env->FileExists(api_def_file_pattern).ok()) { + TF_CHECK_OK(api_def_map.LoadFile(env, api_def_file_pattern)); + } + } + } + } + + api_def_map.UpdateDocs(); + + WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames); } } // namespace @@ -35,18 +57,24 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, int main(int argc, char* argv[]) { tensorflow::port::InitMain(argv[0], &argc, &argv); - if (argc != 5) { + // TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt + // as an argument. + if (argc != 6) { for (int i = 1; i < argc; ++i) { fprintf(stderr, "Arg %d = %s\n", i, argv[i]); } fprintf(stderr, - "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n" + "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal " + "api_def_dirs1,api_def_dir2 ...\n" " include_internal: 1 means include internal ops\n", argv[0]); exit(1); } bool include_internal = tensorflow::StringPiece("1") == argv[4]; - tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal); + std::vector api_def_dirs = tensorflow::str_util::Split( + argv[5], ",", tensorflow::str_util::SkipEmpty()); + tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal, + api_def_dirs); return 0; } diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b7e720a5c7b343415eee1aa157b8de755a1e1a5 --- /dev/null +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/framework/cc_op_gen.h" + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +// TODO(annarev): Remove this op_gen_overrides.pbtxt reference. +// It is needed only because WriteCCOps takes it as an argument. +constexpr char kOverridesFnames[] = + "tensorflow/cc/ops/op_gen_overrides.pbtxt"; +constexpr char kBaseOpDef[] = R"( +op { + name: "Foo" + input_arg { + name: "images" + description: "Images to process." + } + input_arg { + name: "dim" + description: "Description for dim." + type: DT_FLOAT + } + output_arg { + name: "output" + description: "Description for output." + type: DT_FLOAT + } + attr { + name: "T" + type: "type" + description: "Type for images" + allowed_values { + list { + type: DT_UINT8 + type: DT_INT8 + } + } + default_value { + i: 1 + } + } + summary: "Summary for op Foo." + description: "Description for op Foo." +} +)"; + +void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(s.contains(expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + +void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) { + EXPECT_FALSE(s.contains(expected)) + << "'" << s << "' contains '" << expected << "'"; +} + +void ExpectSubstrOrder(const string& s, const string& before, + const string& after) { + int before_pos = s.find(before); + int after_pos = s.find(after); + ASSERT_NE(std::string::npos, before_pos); + ASSERT_NE(std::string::npos, after_pos); + EXPECT_LT(before_pos, after_pos) + << before << " is not before " << after << " in " << s; +} + +// Runs WriteCCOps and stores output in (internal_)cc_file_path and +// (internal_)h_file_path. +void GenerateCcOpFiles(Env* env, const OpList& ops, + const ApiDefMap& api_def_map, string* h_file_text, + string* internal_h_file_text) { + const string& tmpdir = testing::TmpDir(); + + const auto h_file_path = io::JoinPath(tmpdir, "test.h"); + const auto cc_file_path = io::JoinPath(tmpdir, "test.cc"); + const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h"); + const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc"); + + WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames); + + TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text)); + TF_ASSERT_OK( + ReadFileToString(env, internal_h_file_path, internal_h_file_text)); +} + +TEST(CcOpGenTest, TestVisibilityChangedToHidden) { + const string api_def = R"( +op { + graph_op_name: "Foo" + visibility: HIDDEN +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + ApiDefMap api_def_map(op_defs); + + string h_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo"); + ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(internal_h_file_text, "class Foo"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo"); +} + +TEST(CcOpGenTest, TestArgNameChanges) { + const string api_def = R"( +op { + graph_op_name: "Foo" + arg_order: "dim" + arg_order: "images" +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + + ApiDefMap api_def_map(op_defs); + string cc_file_text, h_file_text; + string internal_cc_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectSubstrOrder(h_file_text, "Input images", "Input dim"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectSubstrOrder(h_file_text, "Input dim", "Input images"); +} + +TEST(CcOpGenTest, TestEndpoints) { + const string api_def = R"( +op { + graph_op_name: "Foo" + endpoint { + name: "Foo1" + } + endpoint { + name: "Foo2" + } +} +)"; + Env* env = Env::Default(); + OpList op_defs; + protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT + + ApiDefMap api_def_map(op_defs); + string cc_file_text, h_file_text; + string internal_cc_file_text, internal_h_file_text; + // Without ApiDef + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo {"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo1"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo2"); + + // With ApiDef + TF_ASSERT_OK(api_def_map.LoadApiDef(api_def)); + GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text, + &internal_h_file_text); + ExpectHasSubstr(h_file_text, "class Foo1"); + ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2"); + ExpectDoesNotHaveSubstr(h_file_text, "class Foo {"); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 09fadfcab51575798286876f9a4e0ee9a60940ac..13a3bba5e6d5ca19ff3f0eca76665ba7d3ab628d 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -196,6 +196,18 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status LRNGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs){ + internal::LRNGrad::Attrs grad_attrs; + + auto dx = internal::LRNGrad(scope, grad_inputs[0], op.input(0), op.output(0), + grad_attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("LRN", LRNGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index ac66f51cf01911957722e94ca28e8e78dc6de2ed..f9063e836509669d81d03b1d2f0d32d1166b6eca 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -191,5 +191,12 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, LRN){ + TensorShape x_shape({1, 1, 2, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = LRN(scope_, x); + RunTest(x, x_shape, y, x_shape); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt index 0184c82c5afc99990530b902efdf670a2bdbc4bc..4aac990e748b0a79cbc3b353b4121a582b0883b0 100644 --- a/tensorflow/cc/ops/op_gen_overrides.pbtxt +++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt @@ -11,7 +11,7 @@ op { name: "Reverse" skip: true } op { name: "ReverseV2" rename_to: "Reverse" } op { name: "Split" input_rename: { from: "split_dim" to: "axis" } } op { name: "SplitV" input_rename: { from: "split_dim" to: "axis" } } -op { name: "Squeeze" input_rename: { from: "squeeze_dims" to: "axis" } } +op { name: "Squeeze" attr_rename: { from: "squeeze_dims" to: "axis" } } op { name: "Pack" rename_to: "Stack" } op { name: "Unpack" rename_to: "Unstack" } op { name: "Select" rename_to: "Where3" input_rename: { from: "t" to: "x" } input_rename: { from: "e" to: "y" } } diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h index 2b0b2d5c7fb33768494c1781669c1adcb875a579..b71cb263ca42dab7e830c1880ec4b311bc272f82 100644 --- a/tensorflow/cc/saved_model/tag_constants.h +++ b/tensorflow/cc/saved_model/tag_constants.h @@ -21,6 +21,9 @@ namespace tensorflow { /// Tag for the `gpu` graph. constexpr char kSavedModelTagGpu[] = "gpu"; +/// Tag for the `tpu` graph. +constexpr char kSavedModelTagTpu[] = "tpu"; + /// Tag for the `serving` graph. constexpr char kSavedModelTagServe[] = "serve"; diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index a9a6ea84319a18a8fbce648391bf5918ff6d9a08..5740c040e309bad8d7e3bdc468c09a3323fb99e0 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -24,7 +24,6 @@ tf_cc_test( srcs = ["runtime_test.cc"], deps = [ ":runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -111,6 +110,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ae22f7edc423247b34895411d19d7a3c21f86d4f..53da2881b60db9ad39565567623eb86f754559af 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -101,21 +101,8 @@ Status ComputeArgSizes(const CompileResult& compile_result, std::vector* arg_sizes) { const xla::ProgramShape& ps = compile_result.program_shape; for (int i = 0; i < ps.parameters_size(); ++i) { - if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. - const xla::PrimitiveType type = ps.parameters(i).element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(ps)); - } - arg_sizes->push_back(-1); - } else { - arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( - ps.parameters(i), compile_result.pointer_size)); - } + arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( + ps.parameters(i), compile_result.pointer_size)); } return Status::OK(); } @@ -165,11 +152,6 @@ string RewriteWithName(const string& name, string code, Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and is set in the class constructor. - num_args--; - } if (config.feed_size() != num_args) { return errors::InvalidArgument("mismatch between feed_size(", config.feed_size(), ") and num_args(", @@ -418,7 +400,7 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); {{NS_START}} // {{CLASS}} represents a computation previously specified in a @@ -474,7 +456,6 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -483,7 +464,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} {{CLASS}}(const {{CLASS}}&) = delete; @@ -496,8 +477,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -560,8 +541,6 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{HAS_CONTEXT_ARG}}", - compile_result.has_context_arg ? "true" : "false"}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 0f6114666fcc89c631434527d2ae8c92c039ffea..75026c57c04a64186a1e5be6c41e4dd7de8520b7 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -145,11 +145,9 @@ TEST(GenerateHeader, Golden) { { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - xla::ShapeUtil::MakeOpaqueShape(), }, xla::ShapeUtil::MakeTupleShape( {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); - compile_result.has_context_arg = true; compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; string header; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 65f342ce27ef09092f252f791973f245a8cdd6f3..35e50433d63a549bc6fb6a2be9015d7c471509d0 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -19,7 +19,7 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); namespace foo { namespace bar { @@ -48,7 +48,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6]) +// ((unknown): f32[1,2], (unknown): s64[3,4]) -> (u32[5,6]) // // Memory stats: // arg bytes total: 104 @@ -58,11 +58,11 @@ namespace bar { class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. - static constexpr size_t kNumArgs = 3; + static constexpr size_t kNumArgs = 2; // Byte size of each argument buffer. There are kNumArgs entries. static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1}; + static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; return kArgSizes; } @@ -77,7 +77,6 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = true; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -86,7 +85,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} MyClass(const MyClass&) = delete; @@ -99,8 +98,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -236,8 +235,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // Shape of the args and results. static const xla::ProgramShape* StaticProgramShape() { static const xla::ProgramShape* kShape = []() { - static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; - static constexpr int kProtoSize = 50; + static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; + static constexpr int kProtoSize = 46; xla::ProgramShape* shape = new xla::ProgramShape; shape->ParseFromArray(kProto, kProtoSize); return shape; diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b8cc6024cb85e4f6269313927ff66d1d9a1cf79..c87f2b75dfa18ad5c3eda4bd6fcbcb3083ef73fd 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -94,9 +94,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); xla::Computation computation; - TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, - &computation, - &compile_result->has_context_arg)); + TF_RETURN_IF_ERROR( + ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 965c2960816b3acc8d2209e6824d88647de0ce14..e03c5b1aa77c1262ed903aae3072ef65f34d80a2 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -34,7 +34,6 @@ struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; xla::ProgramShape program_shape; // Static shape of args and results. - bool has_context_arg = false; // Is last arg XlaLocalRuntimeContext? string entry_point; // Name of generated function. int pointer_size = 0; // Size of a pointer in bytes. }; diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index ac79c278c1fdf8b6aedcb52121c767b8ba0ad358..6d603a02eb4ceade6832ba67b2981814ee25327a 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 6b037f276ad1d6771b904bb970f45f32ae9531b8..413efd9cea3b6f71574615ad9ca92471ff925781 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -70,7 +70,7 @@ TEST(TFCompileTest, Add) { // Run tests that use set_argN_data separately, to avoid accidentally re-using // non-existent buffers. TEST(TFCompileTest, Add_SetArg) { - AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); int32 arg_x = 10; int32 arg_y = 32; @@ -258,7 +258,7 @@ TEST(TFCompileTest, MatMul2_SetArg) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); foo::bar::MatMulComp matmul( - foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); matmul.set_thread_pool(&device); // Test using the set_argN_data() methods. diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 363d6925a14dfab8b79617449a73727ab55c4527..542451ed2d14fbceca00c6ccb6e28c1c3a0d4321 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -130,6 +130,10 @@ def tf_library(name, graph, config, header_file = name + ".h" object_file = name + ".o" ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") + if type(tfcompile_flags) == type(""): + flags = tfcompile_flags + else: + flags = " ".join(["'" + arg.replace("'", "'\\''") + "'" for arg in (tfcompile_flags or [])]) native.genrule( name=("gen_" + name), srcs=[ @@ -148,7 +152,7 @@ def tf_library(name, graph, config, " --target_triple=" + target_llvm_triple() + " --out_header=$(@D)/" + header_file + " --out_object=$(@D)/" + object_file + - " " + (tfcompile_flags or "")), + " " + flags), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -185,7 +189,7 @@ def tf_library(name, graph, config, " --cpp_class=" + cpp_class + " --target_triple=" + target_llvm_triple() + " --out_session_module=$(@D)/" + session_module_pb + - " " + (tfcompile_flags or "")), + " " + flags), tools=[tfcompile_tool], visibility=visibility, testonly=testonly, @@ -195,8 +199,7 @@ def tf_library(name, graph, config, # The cc_library rule packaging up the header and object file, and needed # kernel implementations. - need_xla_data_proto = (tfcompile_flags and - tfcompile_flags.find("--gen_program_shape") != -1) + need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) native.cc_library( name=name, srcs=[object_file], @@ -264,7 +267,6 @@ def tf_library(name, graph, config, srcs=[test_file], deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", "@org_tensorflow//tensorflow/compiler/aot:runtime", "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main", "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", @@ -310,7 +312,6 @@ def tf_library(name, graph, config, linkopts = if_android(["-pie", "-s"]), deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", "@org_tensorflow//tensorflow/compiler/aot:benchmark", "@org_tensorflow//tensorflow/compiler/aot:runtime", "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bf7d9cf14d10f41aa48ea594a8d63db97b9973e1..026a1bf879d373fd0f5f4444b3ce10d01702f82b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -251,6 +251,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 22899ebeebc929055518893b358f7950d380d6f6..dc06b7a4025ddc83bf766b702036297203c16e55 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -48,6 +49,52 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; namespace { +bool AreAllParentsConst(const Node& n, + const gtl::FlatSet& runtime_const_nodes) { + if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") { + // If the current node is itself a cast-to-const, no need + // to look at the incoming edges. + return true; + } + + bool all_parents_const = true; + bool atleast_one_non_control_edge = false; + for (const Edge* in : n.in_edges()) { + atleast_one_non_control_edge = + atleast_one_non_control_edge || !in->IsControlEdge(); + if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) { + all_parents_const = false; + break; + } + } + return all_parents_const && atleast_one_non_control_edge; +} + +void MarkGuaranteedConstants( + const Graph& graph, + const std::vector>& src_arg_pairs) { + gtl::FlatSet guaranteed_const_nodes; + std::vector srcs; + srcs.reserve(src_arg_pairs.size()); + for (const auto& src_arg : src_arg_pairs) { + srcs.push_back(src_arg.first); + } + ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, + /*leave=*/[&guaranteed_const_nodes](Node* n) { + // TODO(vinuraja): Doesn't work in the presence of loops. + if (AreAllParentsConst(*n, guaranteed_const_nodes)) { + guaranteed_const_nodes.insert(n); + } + }); + + for (auto& src_arg : src_arg_pairs) { + if (guaranteed_const_nodes.count(src_arg.first) != 0) { + VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString(); + src_arg.second->AddAttr("_is_guaranteed_constant", true); + } + } +} + // A node/slot pair. // TODO(phawkins): is there a common definition of this? struct NodeSlot { @@ -175,9 +222,11 @@ Status Encapsulator::SplitIntoSubgraphs() { // Map from input graph nodes to subgraph nodes. std::unordered_map node_images; + std::vector> src_arg_pairs; // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes. for (Node* node : graph_in_->op_nodes()) { string func_id = GetFunctionNameAttr(node); + if (func_id.empty()) continue; Subgraph& subgraph = subgraphs_[func_id]; @@ -276,11 +325,13 @@ Status Encapsulator::SplitIntoSubgraphs() { kArgOp); builder.Attr("T", dtype); builder.Attr("index", arg_index); + s = builder.Finalize(&arg_def); if (!s.ok()) return s; Node* arg = dst_subgraph.graph->AddNode(arg_def, &s); if (!s.ok()) return s; + src_arg_pairs.push_back({edge->src(), arg}); dst_subgraph.args.push_back(arg); } @@ -292,6 +343,8 @@ Status Encapsulator::SplitIntoSubgraphs() { } } + MarkGuaranteedConstants(*graph_in_, src_arg_pairs); + for (auto& entry : subgraphs_) { FixupSourceAndSinkEdges(entry.second.graph.get()); } diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 4a1dbaf05dc7824835f3567c6abcf48222720230..717efb360185f1ce26ee1e9adb0ee5bf7f4799f8 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -398,5 +398,109 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : graph.nodes()) { + if (node->name() == name) return node; + } + return nullptr; +} + +bool HasGuaranteeConstAttr(const Node& n) { + bool is_guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", + &is_guaranteed_constant) + .ok()) { + return false; + } + return is_guaranteed_constant; +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2); + add1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + EXPECT_EQ(2, guaranteed_consts); +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto const_guarantee_x2 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2); + auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"), + const_guarantee_x1, const_guarantee_x2); + auto add2 = ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2); + auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2); + mul1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const + // and another non-const, so overall non-const. + EXPECT_EQ(1, guaranteed_consts); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 459a582e157f5ddc63997ca93e7c0294293517d3..9bea5663319c8a25249fdc265cee0191556a7c04 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -16,7 +16,6 @@ cc_library( "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 27c5da08c112664d361b5f969d100eed7b9df65c..39a770ab7b9ae56bd24865b86c69331b0a38ccec 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -103,7 +102,6 @@ xla::StatusOr XlaAllocator::Allocate( } void* data = reinterpret_cast(const_cast(t.tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = t; return gpu::DeviceMemoryBase(data, size); } @@ -111,7 +109,6 @@ xla::StatusOr XlaAllocator::Allocate( Status XlaAllocator::RegisterArgument(const Tensor* t) { void* data = reinterpret_cast(const_cast(t->tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = *t; return Status::OK(); } @@ -257,7 +254,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); - options.local_executable_has_hybrid_result = true; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; @@ -268,7 +264,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Builds an XLA allocator for the device. XlaAllocator xla_allocator(client->platform(), ctx); - XlaLocalRuntimeContext local_runtime_context; std::unique_ptr output; // Build xla::ShapedBuffers that point directly to the Tensor buffers. @@ -301,18 +296,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); } - // Make the final parameter point at local_runtime_context. - if (kernel->requires_runtime_context) { - gpu::DeviceMemoryBase local_runtime_context_dmem( - &local_runtime_context, sizeof(local_runtime_context)); - arg_buffers.push_back( - xla::ShapedBuffer::MakeArrayShapedBuffer( - xla::ShapeUtil::MakeOpaqueShape(), client->platform(), - client->default_device_ordinal(), local_runtime_context_dmem) - .ConsumeValueOrDie()); - arg_ptrs.push_back(arg_buffers.back().get()); - } - // Execute the computation. VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; @@ -324,12 +307,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { auto run_result = executable->Run(arg_ptrs, run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - if (local_runtime_context.error) { - ctx->CtxFailure(errors::InvalidArgument("Compiled kernel returned error: ", - local_runtime_context.error_msg)); - return; - } - output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc index 09aee39d8cd0e910320674fcfd8a7884ce2fdd04..4bc209b7ecf499d82e7567f7eff12b17cefa9863 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc @@ -39,21 +39,23 @@ static void AllocateFlags() { flags->tf_xla_min_cluster_size = 2; flags->tf_xla_max_cluster_size = std::numeric_limits::max(); flags->tf_xla_clustering_debug = false; - flag_list = new std::vector({ - Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, - "Control compilation of operators into XLA computations on CPU and " - "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " - "things very likely to be improved; 2 = on for everything. " - "Experimental."), - Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, - "Minimum number of operators in an XLA compilation. Ignored for " - "operators placed on an XLA device or operators explicitly marked " - "for compilation."), - Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, - "Maximum number of operators in an XLA compilation."), - Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, - "Dump graphs during XLA compilation."), - }); + flags->tf_xla_cpu_global_jit = false; + flag_list = new std::vector( + {Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit, + "Control compilation of operators into XLA computations on CPU and " + "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for " + "things very likely to be improved; 2 = on for everything. " + "Experimental."), + Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size, + "Minimum number of operators in an XLA compilation. Ignored for " + "operators placed on an XLA device or operators explicitly marked " + "for compilation."), + Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size, + "Maximum number of operators in an XLA compilation."), + Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug, + "Dump graphs during XLA compilation."), + Flag("tf_xla_cpu_global_jit", &flags->tf_xla_cpu_global_jit, + "Enables global JIT compilation for CPU via SessionOptions.")}); xla::legacy_flags::ParseFlagsFromEnv(*flag_list); } diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h index 24f80507428b6742c64d3d7e96e4b1c540eda01b..e1ccd7ddb8706ca445b6811ca1fec369af7cd5d5 100644 --- a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h +++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h @@ -46,6 +46,8 @@ typedef struct { int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA // compilation. bool tf_xla_clustering_debug; // Dump graphs during XLA compilation. + bool tf_xla_cpu_global_jit; // Enables global JIT compilation for CPU + // via SessionOptions. } MarkForCompilationPassFlags; // Return a pointer to the MarkForCompilationPassFlags struct; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 78d0aa86a8fae9a0c6035bdc579ef800337df917..aceedeb823ac47a36435e36e586f219d313ed121 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -210,6 +210,13 @@ Status FindCompilationCandidates( !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { continue; } + // _Retval nodes in a top-level function represent fetches. + // Do not compile them. + if (node->type_string() == "_Retval") { + VLOG(2) << "Compilation rejected node: return value " << node->name() + << ": " << node->type_string(); + continue; + } candidates->insert(node); } return Status::OK(); @@ -290,9 +297,11 @@ Status MarkForCompilationPass::Run( global_jit_level = static_cast(flags->tf_xla_auto_jit); } + bool cpu_global_jit = flags->tf_xla_cpu_global_jit; const FunctionLibraryDefinition* fld = options.flib_def; - auto is_compilable = [global_jit_level, fld](const Node* node, - const DeviceType& device_type) { + + auto is_compilable = [global_jit_level, cpu_global_jit, fld]( + const Node* node, const DeviceType& device_type) { const XlaOpRegistry::DeviceRegistration* registration; if (!XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) { @@ -315,7 +324,11 @@ Status MarkForCompilationPass::Run( if (status.ok()) return compile; // Otherwise use the value of global_jit_level. - return registration->enable_jit_by_default && global_jit_level > 0; + // Ignore enable_jit_by_default if global jit compilation for CPU + // is explicitly requested via tf_xla_cpu_global_jit flag + bool ignore_registration = cpu_global_jit && device_type == DEVICE_CPU; + return (ignore_registration || registration->enable_jit_by_default) && + global_jit_level > 0; }; return RunImpl(options, is_compilable); } @@ -556,6 +569,7 @@ Status MarkForCompilationPass::RunImpl( if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation || registration->requires_compilation) { string& name = cluster_names[cluster]; + if (name.empty()) { name = strings::StrCat("cluster_", cluster_sequence_num++); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index b3d258aea177fbefa4bae51d8156da2ff86c9032..454f0aeae98d7afd51f12b2cfb1810de275a57f7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -525,5 +525,32 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { "+-- c\n")); } +TEST(XlaCompilationTest, Retval) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::UnaryOp("_Retval", b, + builder.opts() + .WithName("R") + .WithAttr("T", DT_FLOAT) + .WithAttr("index", 0)); + + TF_EXPECT_OK(builder.ToGraph(graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + + EXPECT_EQ(2, clusters.size()); + EXPECT_TRUE(clusters.find("R") == clusters.cend()); + EXPECT_EQ(clusters["A"], clusters["B"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 23368b6c76a363882956577a20c1bd041211d234..3717c2cc24283e0b218f92ec820d16893cbe0c35 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -214,23 +214,15 @@ Status XlaCompilationCache::BuildExecutable( const XlaCompiler::CompilationResult& result, std::unique_ptr* executable) { VLOG(2) << "Compiling to local executable"; - xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); std::vector argument_layouts( result.xla_input_shapes.size()); for (int i = 0; i < result.xla_input_shapes.size(); ++i) { argument_layouts[i] = &result.xla_input_shapes[i]; } - if (result.requires_runtime_context) { - // The final arg is the XlaLocalRuntimeContext*. - argument_layouts.push_back(&opaque_shape); - } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client_->default_device_ordinal()); - build_options.set_platform(client_->platform()); build_options.set_result_layout(result.xla_output_shape); - build_options.set_has_hybrid_result( - options.local_executable_has_hybrid_result); auto compile_result = client_->Compile(*result.computation, argument_layouts, build_options); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 0ff99c5156ded2ae05c6976e3da8f31fce32f8f2..8ace678daa1e9c69af72b65941586ef63a7757a5 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -117,6 +117,33 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "categorical_op_test", + size = "small", + srcs = ["categorical_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + ], +) + +tf_xla_py_test( + name = "cholesky_op_test", + size = "small", + srcs = ["cholesky_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "clustering_test", size = "small", @@ -252,6 +279,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "image_ops_test", + size = "small", + srcs = ["image_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:image_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -389,6 +429,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "scan_ops_test", + size = "small", + srcs = ["scan_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "segment_reduction_ops_test", size = "medium", @@ -430,6 +484,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "stateless_random_ops_test", + size = "small", + srcs = ["stateless_random_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/contrib/stateless", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "tensor_array_ops_test", size = "small", @@ -645,7 +712,7 @@ tf_library( cpp_class = "LSTMLayerInference", graph = "lstm_layer_inference.pbtxt", tags = ["manual"], - tfcompile_flags = "--xla_cpu_multi_thread_eigen=false", + tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], ) # ----------------------------------------------------------------------------- diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index d412c572ae16b84c2434819aa0a2d881defef5f9..654dc15e86b21c7742d49281d53c1a75e6a45d3b 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -366,16 +366,52 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._real_div, - np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype), - np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype), + np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), + np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), + expected=np.array( + [1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2], + dtype=dtype)) + + # Test inf/nan scenarios. + self._testBinary( + gen_math_ops._real_div, + np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype), + np.array([0, 0, 0, 0, 0, 0], dtype=dtype), expected=np.array( [ - 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2, - float("inf") + dtype(1 + 1j) / 0, + dtype(1) / 0, + dtype(1j) / 0, + dtype(-1) / 0, + dtype(-1j) / 0, + dtype(1 - 1j) / 0 ], dtype=dtype)) - # TODO(b/65408531): support+test pow for cplx + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._testBinary( + math_ops.pow, + dtype(3 + 2j), + dtype(4 - 5j), + expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) + self._testBinary( # empty rhs + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], dtype=dtype)) + self._testBinary( # to zero power + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[1, 2], dtype=dtype), + expected=np.ones(shape=[1, 2], dtype=dtype)) + lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) + rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) + scalar = dtype(2 + 2j) + self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) + self._testBinary( + math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) + self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) @@ -385,7 +421,9 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) - # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow) + if atan2_supported: + self._testBinary( + gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..035cdea1786d39f3d21bb63be5c8ccffe1608bdf --- /dev/null +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -0,0 +1,143 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for multinomial generation ops in the XLA JIT compiler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import googletest + + +# TODO(srvasude): Merge this with +# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py. +class CategoricalTest(XLATestCase): + """Test cases for random-number generating operators.""" + + def output_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _chi2(self, expected, actual): + """Returns Chi2 GOF statistic.""" + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected) + return chi2 + + def _do_sampling(self, logits, num_samples): + """Categorical samples from given input. + + Args: + logits: Numpy ndarray of shape [batch_size, num_classes]. + num_samples: Int; number of samples to draw. + + Returns: + Frequencies from sampled classes; shape [batch_size, num_classes]. + """ + with self.test_session() as sess, self.test_scope(): + random_seed.set_random_seed(1618) + op = random_ops.multinomial(logits, num_samples, + output_dtype=dtypes.int32) + d = sess.run(op) + + batch_size, num_classes = logits.shape + freqs_mat = [] + for i in range(batch_size): + cnts = dict(collections.Counter(d[i, :])) + + # Requires drawn class labels be in range. + self.assertLess(max(cnts.keys()), num_classes) + self.assertGreaterEqual(min(cnts.keys()), 0) + + freqs = [(cnts[k] * 1. / num_samples if k in cnts else 0) + for k in range(num_classes)] + freqs_mat.append(freqs) + + return freqs_mat + + def _testRngIsNotConstant(self, rng, dtype, output_dtype): + # Tests that 'rng' does not always return the same value. + with self.test_session() as sess: + with self.test_scope(): + x = rng(dtype, output_dtype) + + # The random-number generator, if working correctly, should produce the + # same output multiple times with low probability. + y = sess.run(x) + z = sess.run(x) + w = sess.run(x) + + # We use exact equality here. If the random-number generator is producing + # deterministic output, all three outputs will be bitwise identical. + self.assertTrue((not np.array_equal(y, z)) or + (not np.array_equal(z, w)) or + (not np.array_equal(y, w))) + + def testCategoricalIsNotConstant(self): + def rng(dtype, output_dtype): + return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10, + output_dtype=output_dtype) + + dtype = np.float32 + for output_dtype in self.output_dtypes(): + self._testRngIsNotConstant(rng, dtype, output_dtype) + + def testCategoricalIsInRange(self): + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.test_session() as sess: + with self.test_scope(): + x = random_ops.multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), 1000, + output_dtype=output_dtype) + y = sess.run(x) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) + + def testSamplingCorrectness(self): + np.random.seed(1618) # Make it reproducible. + num_samples = 21000 + + rand_probs = np.random.dirichlet([1., 1., 2., 3.]) + rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3) # batched + for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]: + probs = np.asarray(probs) + if len(probs.shape) == 1: + probs = probs.reshape(1, probs.size) # singleton batch + + logits = np.log(probs).astype(np.float32) + freqs = self._do_sampling(logits, num_samples) + + # the test here is similar to + # python/kernel_tests/random/multinomial_op_test.py + # Note that df >= 1 in all these cases. Choosing a cutoff of 1e-3 + # corresponds to an alpha value of 2.5% for df = 1, and smaller for larger + # df. + chi2 = self._chi2(probs, freqs) + self.assertLess(chi2, 1e-3) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/cholesky_op_test.py b/tensorflow/compiler/tests/cholesky_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5010fe5e21d0782e68d4e6d5bf6b4df1b44793a3 --- /dev/null +++ b/tensorflow/compiler/tests/cholesky_op_test.py @@ -0,0 +1,126 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.Cholesky.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import unittest + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class CholeskyOpTest(XLATestCase): + + def _verifyCholeskyBase(self, sess, placeholder, x, chol, verification, atol): + chol_np, verification_np = sess.run([chol, verification], {placeholder: x}) + self.assertAllClose(x, verification_np, atol=atol) + self.assertShapeEqual(x, chol) + # Check that the cholesky is lower triangular, and has positive diagonal + # elements. + if chol_np.shape[-1] > 0: + chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2], + chol_np.shape[-1])) + for chol_matrix in chol_reshaped: + self.assertAllClose(chol_matrix, np.tril(chol_matrix), atol=atol) + self.assertTrue((np.diag(chol_matrix) > 0.0).all()) + + def _verifyCholesky(self, x, atol=1e-6): + # Verify that LL^T == x. + with self.test_session() as sess: + placeholder = array_ops.placeholder( + dtypes.as_dtype(x.dtype), shape=x.shape) + with self.test_scope(): + chol = linalg_ops.cholesky(placeholder) + verification = math_ops.matmul(chol, chol, adjoint_b=True) + self._verifyCholeskyBase(sess, placeholder, x, chol, verification, atol) + + def testBasic(self): + data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]) + for dtype in self.float_types: + self._verifyCholesky(data.astype(dtype)) + + def testBatch(self): + for dtype in self.float_types: + simple_array = np.array( + [[[1., 0.], [0., 5.]]], dtype=dtype) # shape (1, 2, 2) + self._verifyCholesky(simple_array) + self._verifyCholesky(np.vstack((simple_array, simple_array))) + odd_sized_array = np.array( + [[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]], dtype=dtype) + self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array))) + + # Generate random positive-definite matrices. + matrices = np.random.rand(10, 5, 5).astype(dtype) + for i in xrange(10): + matrices[i] = np.dot(matrices[i].T, matrices[i]) + self._verifyCholesky(matrices, atol=1e-4) + + def testNonSquareMatrix(self): + for dtype in self.float_types: + with self.assertRaises(ValueError): + linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]], dtype=dtype)) + with self.assertRaises(ValueError): + linalg_ops.cholesky( + np.array( + [[[1., 2., 3.], [3., 4., 5.]], [[1., 2., 3.], [3., 4., 5.]]], + dtype=dtype)) + + def testWrongDimensions(self): + for dtype in self.float_types: + tensor3 = constant_op.constant([1., 2.], dtype=dtype) + with self.assertRaises(ValueError): + linalg_ops.cholesky(tensor3) + with self.assertRaises(ValueError): + linalg_ops.cholesky(tensor3) + + @unittest.skip("Test is slow") + def testLarge(self): + n = 200 + shape = (n, n) + data = np.ones(shape).astype(np.float32) / (2.0 * n) + np.diag( + np.ones(n).astype(np.float32)) + self._verifyCholesky(data, atol=1e-4) + + def testMatrixConditionNumbers(self): + for dtype in self.float_types: + condition_number = 1000 + size = 20 + + # Generate random positive-definite symmetric matrices, and take their + # Eigendecomposition. + matrix = np.random.rand(size, size) + matrix = np.dot(matrix.T, matrix) + _, w = np.linalg.eigh(matrix) + + # Build new Eigenvalues exponentially distributed between 1 and + # 1/condition_number + v = np.exp(-np.log(condition_number) * np.linspace(0, size, size) / size) + matrix = np.dot(np.dot(w, np.diag(v)), w.T).astype(dtype) + self._verifyCholesky(matrix, atol=1e-4) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py index cbe2888696c87c6c2f50c3de71e8531977ea395a..11d8a99ffe1a136a54b16e20f1792062203f7969 100644 --- a/tensorflow/compiler/tests/function_test.py +++ b/tensorflow/compiler/tests/function_test.py @@ -24,10 +24,12 @@ from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest +@test_util.with_c_api class FunctionTest(XLATestCase): def testFunction(self): diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 936fcf8b6be0f8cd67ba07a8bef9d35a732d30ba..a80d69fa5f5099b8a8b67df0da9c92b957e9d194 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -36,7 +36,7 @@ class FusedBatchNormTest(XLATestCase): x_square = x * x x_square_sum = np.sum(x_square, (0, 1, 2)) x_sum = np.sum(x, axis=(0, 1, 2)) - element_count = np.size(x) / int(np.shape(x)[0]) + element_count = np.size(x) / int(np.shape(x)[-1]) mean = x_sum / element_count var = x_square_sum / element_count - mean * mean normalized = (x - mean) / np.sqrt(var + epsilon) @@ -64,8 +64,9 @@ class FusedBatchNormTest(XLATestCase): return grad_x, grad_scale, grad_offset def testInference(self): - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) @@ -74,8 +75,9 @@ class FusedBatchNormTest(XLATestCase): with self.test_session() as sess, self.test_scope(): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") - scale = array_ops.placeholder(np.float32, shape=[2], name="scale") - offset = array_ops.placeholder(np.float32, shape=[2], name="offset") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + offset = array_ops.placeholder( + np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format) @@ -97,8 +99,9 @@ class FusedBatchNormTest(XLATestCase): self.assertAllClose(y_val, y_ref, atol=1e-3) def _testLearning(self, use_gradient_checker): - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) @@ -109,8 +112,9 @@ class FusedBatchNormTest(XLATestCase): with self.test_session() as sess, self.test_scope(): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") - scale = array_ops.placeholder(np.float32, shape=[2], name="scale") - offset = array_ops.placeholder(np.float32, shape=[2], name="offset") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + offset = array_ops.placeholder( + np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y, mean, var = nn.fused_batch_norm( t_val, @@ -151,11 +155,12 @@ class FusedBatchNormTest(XLATestCase): def testLearningWithGradientChecker(self): self._testLearning(True) - def testGradient(self): + def testGradientTraining(self): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] grad_val = np.random.random_sample(x_shape).astype(np.float32) x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) @@ -170,7 +175,7 @@ class FusedBatchNormTest(XLATestCase): var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC") + grad, x, scale, mean, var, data_format="NHWC", is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { @@ -188,6 +193,53 @@ class FusedBatchNormTest(XLATestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + def testGradientInference(self): + # TODO(b/64270657): Use gradient_checker here in addition to comparing with + # this reference implementation. + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] + grad_val = np.random.random_sample(x_shape).astype(np.float32) + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + mean_val = np.random.random_sample(scale_shape).astype(np.float32) + var_val = np.random.random_sample(scale_shape).astype(np.float32) + + with self.test_session() as sess, self.test_scope(): + grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") + x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") + var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + with self.test_scope(): + out = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad_x, grad_scale, grad_offset, _, _ = out + + ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + + grad_x_val, grad_scale_val, grad_offset_val, = sess.run( + [grad_x, grad_scale, grad_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run( + [ref_x, ref_scale, ref_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + + self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) + self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) + self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 664c77f2000281e3be989665664c1be58d4dd1e5..13cbe6f312f5175edaec28fa7a8f28064194b0e9 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -45,7 +45,7 @@ class GatherTest(xla_test.XLATestCase): with self.test_session() as session, self.test_scope(): data = np.array([0, 1, 2, 3, 7, 5]) for dtype in self.all_tf_types: - for indices in 4, [1, 2, 2, 4, 5]: + for indices in 4, [4], [1, 2, 2, 4, 5]: params_np = self._buildParams(data, dtype) params = array_ops.placeholder(dtype=dtype) indices_tf = constant_op.constant(indices) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a04f376ebf6092fd9b6e879796454b1a5c648c96 --- /dev/null +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -0,0 +1,142 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for image ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_image_ops +from tensorflow.python.platform import test + + +class ResizeBilinearTest(XLATestCase): + + def _assertForwardOpMatchesExpected(self, + image_np, + target_shape, + expected=None): + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + image = array_ops.placeholder(image_np.dtype) + resized = gen_image_ops.resize_bilinear( + image, target_shape, align_corners=True) + out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def _assertBackwardOpMatchesExpected(self, + grads_np, + input_shape=None, + dtype=None, + expected=None): + if input_shape is None: + self.fail("input_shape must be specified") + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + dtype = dtype or np.float32 + grads = array_ops.placeholder(np.float32) + resized = gen_image_ops._resize_bilinear_grad( + grads, + np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), + align_corners=True) + out = sess.run(resized, {grads: grads_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def testAlignCorners1x2To3x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) + + def testAlignCorners1x2To3x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), + input_shape=[1, 2], + dtype=dtype, + expected=np.array([[9, 12]], dtype=np.float32)) + + def testAlignCorners2x2To1x1(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [1, 1], + expected=np.array([[1]], dtype=np.float32)) + + def testAlignCorners2x2To1x1Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7]], dtype=np.float32), + input_shape=[2, 2], + dtype=dtype, + expected=np.array([[7, 0], [0, 0]], dtype=np.float32)) + + def testAlignCorners2x2To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32)) + + def testAlignCorners2x2To3x3Grad(self): + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[2, 2], + expected=np.array([[5.25, 8.25], [14.25, 17.25]], dtype=np.float32)) + + def testAlignCorners3x3To2x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [2, 2], + expected=np.array([[1, 3], [7, 9]], dtype=np.float32)) + + def testAlignCorners3x3To2x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7, 13], [22, 4]], dtype=np.float32), + input_shape=[3, 3], + dtype=dtype, + expected=np.array( + [[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32)) + + def testAlignCorners4x4To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=dtype), [3, 3], + expected=np.array( + [[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32)) + + def testAlignCorners4x4To3x3Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[4, 4], + dtype=dtype, + expected=np.array( + [[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3], + [7, 4, 4, 9]], + dtype=np.float32)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 6a8c3bcd55a6e454a19b6249cf4eb48739c8657f..798daaadbc5be50ef9cf7e1205f6d5a0bde59640 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -2460,6 +2460,36 @@ TEST_F(OpTest, Reshape) { }); } +TEST_F(OpTest, ResizeBilinear) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinear") + .RandomInput(DT_FLOAT, in_dims) + .Input(test::AsTensor( + std::vector(out_dims.begin(), out_dims.end()))) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + +TEST_F(OpTest, ResizeBilinearGrad) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinearGrad") + .RandomInput(DT_FLOAT, in_dims) + .RandomInput(DT_FLOAT, + {in_dims[0], out_dims[0], out_dims[1], in_dims[3]}) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + TEST_F(OpTest, Reverse) { Repeatedly([this]() { std::vector dims = RandomDims(1); diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index efda2cc207b2ab56774d193117a2237f3afbfb55..965fdf684b973498d0b3c3cde17711cca7279705 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -67,25 +67,37 @@ class ReduceOpsTest(XLATestCase): np.arange(-10, -4).reshape(2, 3), np.arange(-4, 2).reshape(2, 3), ] - NONEMPTY_FLOAT_DATA = [ - np.arange(1, 7).reshape(2, 3), - np.arange(-10, -4).reshape(2, 3), - np.arange(-4, 2).reshape(2, 3), + COMPLEX_DATA = [ + np.zeros(shape=(2, 0)).astype(np.complex64), + np.zeros(shape=(0, 30)).astype(np.complex64), + np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3), + np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3), + np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3), ] + NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0] + NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0] BOOL_DATA = [ np.array([], dtype=np.bool).reshape(2, 0), np.array([], dtype=np.bool).reshape(0, 3), np.array([[False, True, False], [True, True, False]]), ] - def testReduceSum(self): + def testReduceSumF32(self): self._testReduction(math_ops.reduce_sum, np.sum, np.float32, self.FLOAT_DATA) - def testReduceProd(self): + def testReduceSumC64(self): + self._testReduction(math_ops.reduce_sum, np.sum, np.complex64, + self.COMPLEX_DATA) + + def testReduceProdF32(self): self._testReduction(math_ops.reduce_prod, np.prod, np.float32, self.FLOAT_DATA) + def testReduceProdC64(self): + self._testReduction(math_ops.reduce_prod, np.prod, np.complex64, + self.COMPLEX_DATA) + def testReduceMin(self): def reference_min(inp, axis): @@ -108,12 +120,16 @@ class ReduceOpsTest(XLATestCase): self._testReduction(math_ops.reduce_max, reference_max, np.float32, self.FLOAT_DATA) - def testReduceMean(self): + def testReduceMeanF32(self): # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when # reducing across zero inputs. self._testReduction(math_ops.reduce_mean, np.mean, np.float32, self.NONEMPTY_FLOAT_DATA) + def testReduceMeanC64(self): + self._testReduction(math_ops.reduce_mean, np.mean, np.complex64, + self.NONEMPTY_COMPLEX_DATA) + def testReduceAll(self): self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA) diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3260e63b23226d736a7ddc0f21a94a8c791e0442 --- /dev/null +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -0,0 +1,229 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for scan ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def numpy_reverse(x, axis): + length = len(x.shape) + if axis < 0: + axis = length + axis + + ix = [ + slice(None, None, -1) if i == axis else slice(None) for i in range(length) + ] + return x[ix] + + +def handle_options(func, x, axis, exclusive, reverse): + """Adds tf options to numpy scan ops.""" + length = len(x.shape) + if axis < 0: + axis = length + axis + + if reverse: + x = numpy_reverse(x, axis) + + if exclusive: + ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)] + ix_init = [ + slice(0, -1) if i == axis else slice(None) for i in range(length) + ] + if func == np.cumsum: + init = np.zeros_like(x[ix_head]) + elif func == np.cumprod: + init = np.ones_like(x[ix_head]) + else: + raise ValueError("Unknown scan function.") + x = np.concatenate([init, func(x[ix_init], axis)], axis=axis) + else: + x = func(x, axis=axis) + + if reverse: + x = numpy_reverse(x, axis) + return x + + +class CumsumTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( + feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumsum(p, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 10).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumsum(input_tensor, [0]).eval() + + +class CumprodTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + prod = math_ops.cumprod(p, axis, exclusive, reverse) + tf_out = prod.eval(feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumprod(x, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 11).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumprod(input_tensor, [0]).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4336ebdbd184a081619f0a6951dd4514735c6eb6 --- /dev/null +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for stateless random-number generation ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.contrib import stateless +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class StatelessRandomOpsTest(XLATestCase): + """Test cases for stateless random-number generator operators.""" + + def _random_types(self): + return [dtypes.float32] + + def testDeterminism(self): + # Stateless values should be equal iff the seeds are equal (roughly) + with self.test_session(), self.test_scope(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for stateless_op in [ + stateless.stateless_random_uniform, stateless.stateless_random_normal + ]: + for shape in (), (3,), (2, 5): + for dtype in self._random_types(): + pure = stateless_op(shape, seed=seed_t, dtype=dtype) + values = [(seed, pure.eval(feed_dict={ + seed_t: seed + })) for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + + def testRandomUniformIsInRange(self): + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless.stateless_random_uniform( + shape=[1000], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue(np.all(y >= 0)) + self.assertTrue(np.all(y < 1)) + + def _chi_squared(self, x, bins): + """Pearson's Chi-squared test.""" + x = np.ravel(x) + n = len(x) + histogram, _ = np.histogram(x, bins=bins, range=(0, 1)) + expected = n / float(bins) + return np.sum(np.square(histogram - expected) / expected) + + def testDistributionOfStatelessRandomUniform(self): + """Use Pearson's Chi-squared test to test for uniformity.""" + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 1000 + x = stateless.stateless_random_uniform( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [565656, 121212]}) + # Tests that the values are distributed amongst 10 bins with equal + # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with + # p=0.05. This test is probabilistic and would be flaky if the random + # seed were not fixed. + self.assertTrue(self._chi_squared(y, 10) < 16.92) + + def _normal_cdf(self, x): + """Cumulative distribution function for a standard normal distribution.""" + return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) + + def _anderson_darling(self, x): + """Anderson-Darling test for a standard normal distribution.""" + x = np.sort(np.ravel(x)) + n = len(x) + i = np.linspace(1, n, n) + z = np.sum((2 * i - 1) * np.log(self._normal_cdf(x)) + + (2 * (n - i) + 1) * np.log(1 - self._normal_cdf(x))) + return -n - z / n + + def testDistributionOfStatelessRandomNormal(self): + """Use Anderson-Darling test to test distribution appears normal.""" + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + n = 1000 + x = stateless.stateless_random_normal( + shape=[n], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [25252, 314159]}) + # The constant 2.492 is the 5% critical value for the Anderson-Darling + # test where the mean and variance are known. This test is probabilistic + # so to avoid flakiness the seed is fixed. + self.assertTrue(self._anderson_darling(y) < 2.492) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 76644380bdf2e0c24f6d363ddfaabdff836495d7..0da7442a24201011e3126e53c9d884534a0d721e 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -33,6 +33,17 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest +def nhwc_to_format(x, data_format): + """Converts a numpy array from NHWC format to `data_format`.""" + rank = len(x.shape) + if data_format == "NCHW": + return np.transpose(x, [0, rank - 1] + list(range(1, rank - 1))) + elif data_format == "NHWC": + return x + else: + raise ValueError("Unknown format {}".format(data_format)) + + class UnaryOpsTest(XLATestCase): """Test cases for unary operators.""" @@ -76,6 +87,12 @@ class UnaryOpsTest(XLATestCase): array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.diag, np.array([[1, 2], [3, 4]], dtype=dtype), + np.array( + [[[[1, 0], [0, 0]], [[0, 2], [0, 0]]], [[[0, 0], [3, 0]], + [[0, 0], [0, 4]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.identity, @@ -86,6 +103,21 @@ class UnaryOpsTest(XLATestCase): array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), + np.array( + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, + np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), + np.array( + [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], + [[4, 0, 0], [0, 5, 0], [0, 0, 6]]], + [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], + [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), @@ -330,12 +362,22 @@ class UnaryOpsTest(XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - # TODO(b/65408531): math_ops.acosh (needs pow) - # TODO(b/65408531): math_ops.asinh (needs pow) # TODO(b/65408531): Wider support for log (needs atan2). atan2_supported = self.device == "XLA_GPU" if atan2_supported: + self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arccosh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arcsinh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( math_ops.atanh, np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), @@ -392,19 +434,26 @@ class UnaryOpsTest(XLATestCase): expected=np.log1p( np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) - # TODO(b/34703906): math_ops.rsqrt (needs pow) + val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.rsqrt, val, expected=1 / np.sqrt(val)) + + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) - # TODO(b/34703906): math_ops.sigmoid (needs tanh) + self._assertOpOutputMatchesExpected( + math_ops.sqrt, val, expected=np.sqrt(val)) - # TODO(b/34703906): math_ops.sqrt (needs pow) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.tan, np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) - # TODO(b/34703906): math_ops.tanh (as itself) - ctypes = {np.complex64: np.float32} self._assertOpOutputMatchesExpected( math_ops.abs, @@ -624,55 +673,88 @@ class UnaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) def testDepthToSpace(self): + def make_op(data_format): + def op(x): + return array_ops.depth_to_space(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - expected=np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - expected=np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format)) def testSpaceToDepth(self): + def make_op(data_format): + def op(x): + return array_ops.space_to_depth(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format)) def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index c50342dee45eba6ae54f01653ecc81ef096b547b..b08d6ab21e0746558cb3d4818d4c822c45d2e9ee 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -107,11 +107,26 @@ class VariableOpsTest(XLATestCase): [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], ).astype(dtype), sess.run(x)) + def testShape(self): + for dtype in self.numeric_types: + init = np.ones([2, 3]).astype(dtype) + with self.test_session() as session, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + session.run(variables.variables_initializer([v])) + h = v.handle + s32, s64 = session.run([ + resource_variable_ops.variable_shape(h), + resource_variable_ops.variable_shape(h, out_type=dtypes.int64) + ]) + self.assertEqual(s32.dtype, np.int32) + self.assertEqual(s64.dtype, np.int64) + self.assertAllEqual(s32, [2, 3]) + self.assertAllEqual(s64, [2, 3]) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" for dtype in self.numeric_types: with self.test_session() as session: - print(ops.get_default_graph()) with self.test_scope(): with variable_scope.variable_scope("ascope", use_resource=True): x = variable_scope.get_variable( diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3c94bcafc1d19b1bc54887e6f2c25b1886be646e..5d1cb6d73570a1a3efbe0d2d37d9746bc0e2528f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") package_group( name = "internal", @@ -25,6 +25,30 @@ package( load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +cc_library( + name = "tf2xla_supported_ops_lib", + srcs = ["tf2xla_supported_ops.cc"], + hdrs = ["tf2xla_supported_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_binary( + name = "tf2xla_supported_ops", + srcs = ["tf2xla_supported_ops_main.cc"], + visibility = ["//visibility:public"], + deps = [":tf2xla_supported_ops_lib"], +) + xla_proto_library( name = "tf2xla_proto", srcs = ["tf2xla.proto"], @@ -67,7 +91,6 @@ cc_library( # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], @@ -123,6 +146,9 @@ cc_library( ":const_analysis", ":dump_graph", ":functionalize_control_flow", + ":sharding_util", + ":tf2xla_util", + "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -169,6 +195,35 @@ cc_library( ], ) +cc_library( + name = "sharding_util", + srcs = ["sharding_util.cc"], + hdrs = ["sharding_util.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "sharding_util_test", + srcs = ["sharding_util_test.cc"], + deps = [ + ":sharding_util", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + # Internal targets below this point. cc_library( @@ -176,11 +231,14 @@ cc_library( srcs = ["tf2xla_util.cc"], hdrs = ["tf2xla_util.h"], deps = [ + ":sharding_util", ":tf2xla_proto", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -190,8 +248,14 @@ tf_cc_test( name = "tf2xla_util_test", srcs = ["tf2xla_util_test.cc"], deps = [ + ":sharding_util", ":tf2xla_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", + "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -317,13 +381,6 @@ tf_cc_test( ], ) -cc_library( - name = "xla_local_runtime_context", - hdrs = ["xla_local_runtime_context.h"], - visibility = ["//visibility:public"], - deps = ["//tensorflow/core:framework_lite"], -) - cc_library( name = "dump_graph", srcs = [ @@ -350,6 +407,7 @@ cc_library( srcs = ["functionalize_control_flow.cc"], hdrs = ["functionalize_control_flow.h"], deps = [ + ":tf2xla_util", "//tensorflow/compiler/jit:graph_to_functiondef", "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", @@ -359,6 +417,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 102a2cf07b51486bb445b0311966717b7e82ace6..ab2f1e9a7ab577bbe704e568b21d9912439605ca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -52,6 +52,8 @@ Status BackwardsConstAnalysis(const Graph& g, {"Conv2DBackpropInput", "input_sizes"}, {"Conv3DBackpropFilterV2", "filter_sizes"}, {"Conv3DBackpropInputV2", "input_sizes"}, + {"Cumprod", "axis"}, + {"Cumsum", "axis"}, {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"}, {"DepthwiseConv2dNativeBackpropInput", "input_sizes"}, {"DynamicStitch", "indices"}, @@ -69,6 +71,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"Pad", "paddings"}, {"PadV2", "paddings"}, {"MirrorPad", "paddings"}, + {"Multinomial", "num_samples"}, {"Prod", "reduction_indices"}, {"RandomStandardNormal", "shape"}, {"RandomUniform", "shape"}, @@ -77,6 +80,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"Range", "limit"}, {"Range", "delta"}, {"Reshape", "shape"}, + {"ResizeBilinear", "size"}, {"ResourceStridedSliceAssign", "begin"}, {"ResourceStridedSliceAssign", "end"}, {"ResourceStridedSliceAssign", "strides"}, diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index ddd912b87315f7943915153b5bf73531107af54d..03603ee9baefd1d20d220faf63c9c1c427ebdf31 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -63,7 +63,12 @@ string MakeUniquePath(string name) { string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def)); + Status status = WriteTextProto(Env::Default(), path, graph_def); + if (!status.ok()) { + VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; + path.clear(); + path = "(unavailable)"; + } return path; } @@ -79,7 +84,13 @@ string DumpGraphToFile(const string& name, Graph const& graph, string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef)); + Status status = WriteTextProto(Env::Default(), path, fdef); + if (!status.ok()) { + VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " + << status; + path.clear(); + path = "(unavailable)"; + } return path; } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 35b6960a98cda1bf098f3e01cac3df8173bdc729..267268298c97560a3409b0bdc134526b60e39e5b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" @@ -129,7 +130,9 @@ Status CopySubgraph(const Graph& graph, const Frame* frame, stack.push_back(src); } Node* src_copy = (*node_map)[e->src()->id()]; - int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output(); + int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() + ? 0 + : e->src_output(); Node* dst_copy = (*node_map)[e->dst()->id()]; output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); } @@ -405,7 +408,15 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, arg.merge->name()); } - // Find the Exit successor of the Switch. + // Update the device on the Identity outputs of the switch to match their + // target. These Identity outputs do not + + // Loop over the switch node's output to: + // - Find the Exit successor. + // - Set the sharding on all Identity outputs of the switch. These + // identity nodes are values used by the loop body or condition. + // The Identity node may have the wrong device so copy the device from + // one of its outputs instead. for (const Edge* edge : arg.switch_node->out_edges()) { if (edge->src_output() == 0 && IsExit(edge->dst())) { if (arg.exit != nullptr) { @@ -413,6 +424,9 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, arg.switch_node->name()); } arg.exit = edge->dst(); + } else if (StringPiece(edge->dst()->type_string()) == "Identity") { + TF_RETURN_IF_ERROR( + SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); } } } @@ -609,11 +623,12 @@ class FunctionalizeCond { FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) : clusters_(graph->num_node_ids()), library_(library), graph_(graph) {} - // Returns a vector of Merge nodes from the clustered graph where the nodes + // Returns a vector of Switch nodes from the clustered graph where the nodes // are sorted by the number of switch nodes minus number of merge nodes // from a root of the clustered graph to the given Merge node, with ties - // broken by the representative of the Cluster. - std::vector> SortedMergeNodes(); + // broken by the representative of the Cluster. This corresponds to sorting by + // nesting depth, from deepest nested to outermost. + std::vector> SortedSwitchNodes(); // Returns whether the graph has no conditionals. bool NoConditionals() const { return merge_nodes_.empty(); } @@ -640,15 +655,17 @@ class FunctionalizeCond { // extracting the bodies needed for the then and else branch, creates a XlaIf // node, removing the nodes of the branches from the graph and replacing the // merge node with a XlaIf. - Status ConvertMergeToXlaIf(Cluster* merge_cluster); + Status ConvertCorrespondingMergeToXlaIf(Cluster* switch_cluster); // Removes a Switch cluster feeding directly into a Merge cluster by removing // the Switch and Merge nodes and collapsing into a single cluster. - Status RemoveTrivialMerge(Cluster* merge_cluster); + Status RemoveTrivialSwitch(Cluster* switch_cluster); - // Returns the switch cluster corresponding to the merge node. This function - // only returns the switch cluster in the simple case where we have a switch - // node is the entry of a diamond corresponding to a conditional: + // Returns the merge cluster corresponding to the switch node. This function + // only returns the merge cluster in the case where we have a switch node that + // is the single entry point for all paths to a common merge cluster, this + // merge cluster may be created by combining multiple merge clusters, that + // share the switch cluster as common ancestor, together. // // Switch // / \ @@ -657,8 +674,9 @@ class FunctionalizeCond { // merge_cluster // // Note: either of the branches may be empty. The case where both branches are - // empty is handled by RemoveTrivialMerge. - gtl::optional GetSwitchCluster(const Cluster& merge_cluster); + // empty is handled by RemoveTrivialSwitch. + gtl::optional CreateCorrespondingMergeCluster( + Cluster* switch_cluster); // Determines the arguments needed as input to the Merge cluster originating // from the Switch cluster. @@ -717,11 +735,12 @@ string DebugString(const Graph& graph, FunctionalizeCond::ClusterHandle::Vector* clusters) { string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; std::map subgraphs; + auto name = [](const Node* n) { + return strings::StrCat(n->type_string(), "_", n->id()); + }; for (Node* n : graph.nodes()) { - if (n->IsOp()) { - strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), - " [label=\"", n->name(), "\"];\n"); - } + strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"", + name(n), "\"];\n"); } for (auto kv : subgraphs) { strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", @@ -729,16 +748,11 @@ string DebugString(const Graph& graph, kv.first.ToString(), "\";\n", kv.second, "}\n"); } for (Node* n : graph.nodes()) { - if (!n->IsOp()) { - continue; - } for (Node* in : n->in_nodes()) { - if (in->IsOp()) { - strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); - } + strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); } } - return strings::StrCat(ret, "}"); + return strings::StrCat(ret, "} // end"); } string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { @@ -747,16 +761,24 @@ string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { return cluster.representative.ToString(); }; for (auto kv : clustered_graph) { - strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second), - " (", kv.second.switch_nodes.size(), ", ", - kv.second.merge_nodes.size(), ")\"];\n"); + if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) { + strings::StrAppend( + &ret, kv.first.ToString(), " [label=\"", name(kv.second), + kv.second.switch_nodes.empty() + ? "" + : strings::StrCat(" switches=", kv.second.switch_nodes.size()), + kv.second.merge_nodes.empty() + ? "" + : strings::StrCat(" merges=", kv.second.merge_nodes.size()), + "\"];\n"); + } } for (auto kv : clustered_graph) { for (auto in : kv.second.in_nodes) { strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); } } - return strings::StrCat(ret, "}"); + return strings::StrCat(ret, "} // end"); } bool IsDeadSwitch(const Node* node) { @@ -775,10 +797,11 @@ bool IsDeadSwitch(const Node* node) { } void FunctionalizeCond::CreateClusters() { + ClusterHandle source_cluster = ClusterHandle(Graph::kSourceId); + auto& source = clusters_.at(source_cluster); + std::deque>> workqueue; + workqueue.push_back({source_cluster, {}}); for (Node* node : graph_->nodes()) { - if (!node->IsOp()) { - continue; - } if (IsSwitch(node)) { switch_nodes_.insert(node); } else if (IsMerge(node)) { @@ -786,6 +809,12 @@ void FunctionalizeCond::CreateClusters() { } ClusterHandle& cluster = clusters_.at(node).Get(); cluster = ClusterHandle(node->id()); + // Group all source clusters together. + if (node->IsSource() || node->in_edges().empty()) { + clusters_.at(node).Merge(&source); + source.Merge(&clusters_.at(node)); + workqueue.front().second.push_back(node); + } } // If there are no Merge nodes, then terminate. @@ -800,15 +829,117 @@ void FunctionalizeCond::CreateClusters() { // conservatively assuming all merge nodes become XlaIf nodes. clusters_.resize(clusters_.size() + merge_nodes_.size()); - // Merge a cluster with its input, unless the input is a Switch node or - // the node is a Merge node. - for (const Node* node : graph_->nodes()) { - if (IsMerge(node) || IsSwitch(node) || !node->IsOp()) { - continue; + std::unordered_set marked; + while (!workqueue.empty()) { + auto cluster_queue = workqueue.front(); + VLOG(4) << "Cluster: " << cluster_queue.first << " Queue: {" + << str_util::Join(cluster_queue.second, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, node->id()); + }) + << "}"; + + UnionFind& repr = clusters_.at(cluster_queue.first); + workqueue.pop_front(); + std::deque switch_nodes; + std::deque merge_nodes; + std::unordered_set cluster_member; + while (!cluster_queue.second.empty()) { + // Iterate node workqueue and flow forward merging all nodes reachable + // that are neither a Switch or a Merge and whose inputs are all part of + // the same cluster. + Node* cur = cluster_queue.second.front(); + cluster_queue.second.pop_front(); + if (marked.find(cur) != marked.end()) { + continue; + } + if (IsMerge(cur)) { + merge_nodes.push_back(cur); + marked.insert(cur); + continue; + } + if (IsSwitch(cur)) { + switch_nodes.push_back(cur); + marked.insert(cur); + continue; + } + clusters_.at(cur).Merge(&repr); + cluster_member.insert(cur); + for (Node* out : cur->out_nodes()) { + bool all_ancestors_in_cluster = true; + for (Node* in : out->in_nodes()) { + if (IsMerge(out)) { + merge_nodes.push_back(out); + } + if (IsSwitch(out)) { + switch_nodes.push_back(out); + } + if (cluster_member.find(in) == cluster_member.end()) { + all_ancestors_in_cluster = false; + break; + } + } + if (all_ancestors_in_cluster && out->IsOp()) { + cluster_queue.second.push_back(out); + marked.insert(cur); + } + } } - for (const Node* in : node->in_nodes()) { - if (in->IsOp() && !IsSwitch(in) && !IsMerge(in)) { - clusters_.at(node).Merge(&clusters_.at(in)); + + VLOG(4) << "Switches: {" + << str_util::Join(switch_nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, node->id()); + }) + << "}"; + + // Merge Switch nodes with common predicate. + std::unordered_map> predicate_to_switch; + for (Node* node : switch_nodes) { + Node* tmp; + TF_CHECK_OK(node->input_node(1, &tmp)); + predicate_to_switch[tmp].push_back(node); + } + for (auto kv : predicate_to_switch) { + Node* first = kv.second.front(); + for (Node* switch_node : kv.second) { + clusters_.at(first).Merge(&clusters_.at(switch_node)); + } + } + + // Enqueue each edge of the switch node separately. That is, group all the + // nodes that are due to the true/false edge of the switch together and + // consider all nodes that only have a control dependency on the switch node + // separately. We want to group together all nodes that are part of the same + // branch, as these will be extracted into the `then` and `else` functions + // of the functional if. The ops due to control edges are different as they + // could be involved with either branch and merging them here could result + // in invalid graphs. + for (auto kv : predicate_to_switch) { + ClusterHandle none = ClusterHandle(-1); + ClusterHandle first[2] = {none, none}; + std::deque* queue[2]; + for (auto switch_node : kv.second) { + for (const auto e : switch_node->out_edges()) { + if (IsSwitch(e->dst()) || IsMerge(e->dst())) { + continue; + } + // Control edges are enqueued on their own. + if (e->IsControlEdge()) { + workqueue.push_back({Representative(e->dst()), {e->dst()}}); + continue; + } + // Combine all outputs of the same output port of a switch cluster + // into the same workqueue entry. + if (first[e->src_output()] == none) { + ClusterHandle repr = Representative(e->dst()); + first[e->src_output()] = repr; + workqueue.push_back({repr, {}}); + queue[e->src_output()] = &workqueue.back().second; + } + clusters_.at(first[e->src_output()]).Merge(&clusters_.at(e->dst())); + queue[e->src_output()]->push_back(e->dst()); + } } } } @@ -862,7 +993,7 @@ void FunctionalizeCond::CreateClusteredGraph() { for (const Node* in : node->in_nodes()) { ClusterHandle other_repr = Representative(in); // Skip source, sink and internal edges. - if (!in->IsOp() || other_repr == repr) { + if (other_repr == repr) { continue; } Cluster& cluster_node_in = clustered_graph_[other_repr]; @@ -873,7 +1004,7 @@ void FunctionalizeCond::CreateClusteredGraph() { for (const Node* out : node->out_nodes()) { ClusterHandle other_repr = Representative(out); // Skip source, sink and internal edges. - if (!out->IsOp() || other_repr == repr) { + if (other_repr == repr) { continue; } Cluster& cluster_node_out = clustered_graph_[other_repr]; @@ -883,6 +1014,7 @@ void FunctionalizeCond::CreateClusteredGraph() { } return cluster_node; }; + update_cluster_for_node(graph_->source_node()); for (Node* node : switch_nodes_) { update_cluster_for_node(node).switch_nodes.insert(node); } @@ -890,74 +1022,64 @@ void FunctionalizeCond::CreateClusteredGraph() { update_cluster_for_node(node).merge_nodes.insert(node); } - // Merge Switch nodes with common predicate. - std::unordered_map> predicate_to_switch; - for (Node* node : switch_nodes_) { - Node* tmp; - TF_CHECK_OK(node->input_node(1, &tmp)); - predicate_to_switch[tmp].push_back(node); - } - for (auto kv : predicate_to_switch) { - Cluster& first = clustered_graph_.at(Representative(kv.second.front())); - for (Node* switch_node : kv.second) { - ClusterHandle handle = Representative(switch_node); - Cluster& cluster = clustered_graph_.at(handle); - ContractEdge(&cluster, &first, /*remove_from_graph=*/true); - } - } - - // Merge Merge nodes with common input together. - for (Node* node : merge_nodes_) { - Cluster& cluster = clustered_graph_.at(Representative(node)); - for (const Node* in : node->in_nodes()) { - if (!in->IsOp()) { - continue; - } - Cluster& cluster_node_in = clustered_graph_.at(Representative(in)); - // ContractEdge can modify out_nodes of cluster_node_in, so traverse - // over out_nodes assuming it does. - for (auto it = cluster_node_in.out_nodes.begin(); - it != cluster_node_in.out_nodes.end();) { - if (!(*it)->merge_nodes.empty()) { - ContractEdge(*it++, &cluster, /*remove_from_graph=*/true); - } else { - ++it; - } - } - } - } - VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_); VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_); } -gtl::optional FunctionalizeCond::GetSwitchCluster( - const Cluster& merge_cluster) { - VLOG(3) << "GetSwitchCluster for " << merge_cluster.representative; - gtl::optional switch_cluster; - if (merge_cluster.in_nodes.size() > 2) { - return gtl::nullopt; +gtl::optional +FunctionalizeCond::CreateCorrespondingMergeCluster(Cluster* switch_cluster) { + VLOG(3) << "CreateCorrespondingMergeCluster for " + << switch_cluster->representative; + std::unordered_set merges; + std::unordered_set dominated; + dominated.insert(switch_cluster); + std::deque queue; + auto enqueue_or_update_merge = [this, &queue, &merges](Cluster* c) { + if (c->merge_nodes.empty()) { + queue.push_back(c); + } else { + merges.insert(c); + } + }; + // Enqueue all the outputs of the switch cluster in the workqueue. + for (auto* out : switch_cluster->out_nodes) { + enqueue_or_update_merge(out); } - for (Cluster* in : merge_cluster.in_nodes) { - Cluster* cluster = in; - if (in->switch_nodes.empty()) { - if (in->in_nodes.size() != 1) { + std::unordered_set visited; + while (!queue.empty()) { + Cluster* cur = queue.front(); + queue.pop_front(); + if (visited.find(cur) != visited.end()) { + continue; + } + visited.insert(cur); + // Ensure all inputs to the current node are in the dominated set. + for (Cluster* in : cur->in_nodes) { + if (dominated.find(in) == dominated.end()) { return gtl::nullopt; } - // There is only a single `in` cluster. - cluster = *in->in_nodes.begin(); } - if (cluster->switch_nodes.empty()) { - return gtl::nullopt; - } - - if (switch_cluster.has_value() && *switch_cluster != cluster) { - return gtl::nullopt; - } else { - switch_cluster = cluster; + for (Cluster* out : cur->out_nodes) { + // No switch nodes beyond the entry one is expected. + if (!out->switch_nodes.empty()) { + return gtl::nullopt; + } + enqueue_or_update_merge(out); } } - return switch_cluster; + // Return if there are no merge nodes. + if (merges.empty()) { + return gtl::nullopt; + } + auto it = merges.begin(); + Cluster* merge_cluster = *it; + for (++it; it != merges.end(); ++it) { + ContractEdge(*it, merge_cluster); + } + + // TODO(jpienaar): Clean up graph, merging nodes. + + return merge_cluster; } xla::StatusOr FunctionalizeCond::DetermineCondArgs( @@ -1201,11 +1323,11 @@ void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) { } } -Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) { - Cluster* switch_cluster = *merge_cluster->in_nodes.begin(); - if (switch_cluster->switch_nodes.empty()) { +Status FunctionalizeCond::RemoveTrivialSwitch(Cluster* switch_cluster) { + Cluster* merge_cluster = *switch_cluster->out_nodes.begin(); + if (merge_cluster->merge_nodes.empty()) { return errors::FailedPrecondition( - "Not a trivial merge: no Switch node feeding into Merge node"); + "Not a trivial switch: no Merge node feeding into Switch node"); } for (auto it = merge_cluster->merge_nodes.begin(); @@ -1232,17 +1354,25 @@ Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) { return Status::OK(); } -Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) { - VLOG(1) << "ConvertMergeToXlaIf for " << merge_cluster->representative; - gtl::optional switch_cluster = GetSwitchCluster(*merge_cluster); - if (!switch_cluster.has_value()) { +Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf( + Cluster* switch_cluster) { + VLOG(1) << "ConvertMergeToXlaIf for " << switch_cluster->representative; + gtl::optional maybe_merge = + CreateCorrespondingMergeCluster(switch_cluster); + if (!maybe_merge.has_value()) { return errors::FailedPrecondition( - "Merge cluster was not part of a simple conditional in the clustered " - "graph. Graph nodes in merge cluster ", - NodesToString(merge_cluster->merge_nodes)); + "Switch cluster was not part of a simple conditional in the clustered " + "graph. Graph nodes in switch cluster ", + NodesToString(switch_cluster->switch_nodes)); + } + Cluster* merge_cluster = *maybe_merge; + if (merge_cluster->merge_nodes.empty()) { + return errors::Internal( + "Merge node in clustered graph contains no merge nodes: ", + merge_cluster->representative.ToString()); } TF_ASSIGN_OR_RETURN(auto cond_args, - DetermineCondArgs(*merge_cluster, **switch_cluster)); + DetermineCondArgs(*merge_cluster, *switch_cluster)); // Sort the outputs by ID to produce more stable output. std::vector outputs(merge_cluster->merge_nodes.begin(), @@ -1258,7 +1388,7 @@ Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) { // Remove the old nodes from the graph_ and contract the edges of the // clustered graph. for (auto in : merge_cluster->in_nodes) { - if (in != *switch_cluster) { + if (in != switch_cluster) { RemoveClusterNodes(in); } } @@ -1266,23 +1396,20 @@ Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) { RemoveUnusedArgs(cond_args.args); auto in_nodes = merge_cluster->in_nodes; for (auto it = in_nodes.begin(); it != in_nodes.end();) { - ContractEdge(*it++, merge_cluster); + ContractEdge(*it++, switch_cluster); } - ContractEdge(*switch_cluster, merge_cluster); - clusters_[if_node].Get() = ClusterHandle(merge_cluster->representative); + ContractEdge(merge_cluster, switch_cluster); + clusters_[if_node].Get() = ClusterHandle(switch_cluster->representative); return Status::OK(); } std::vector> -FunctionalizeCond::SortedMergeNodes() { +FunctionalizeCond::SortedSwitchNodes() { VLOG(2) << "ProcessClusteredGraph"; std::stack> stack; - for (auto& c : clustered_graph_) { - if (c.second.in_nodes.empty()) { - stack.push({0, &c.second}); - } - } + // Initialize with the source node. + stack.push({0, &clustered_graph_[Representative(graph_->source_node())]}); // Perform a depth-first traversal of the clustered graph computing the // switch-merge depth. @@ -1300,10 +1427,10 @@ FunctionalizeCond::SortedMergeNodes() { size_t new_depth = depth; if (!n->merge_nodes.empty()) { - queue.emplace_back(depth, n); --new_depth; } if (!n->switch_nodes.empty()) { + queue.emplace_back(depth, n); ++new_depth; } for (Cluster* e : n->out_nodes) { @@ -1333,25 +1460,30 @@ Status FunctionalizeCond::Functionalize(Graph* graph, } fc.CreateClusteredGraph(); - auto queue = fc.SortedMergeNodes(); + auto queue = fc.SortedSwitchNodes(); for (auto it = queue.begin(); it != queue.end();) { - Cluster* merge_cluster = (*it).second; + Cluster* switch_cluster = (*it).second; ++it; - if (merge_cluster->in_nodes.size() == 1) { - TF_RETURN_IF_ERROR(fc.RemoveTrivialMerge(merge_cluster)); + if (switch_cluster->out_nodes.size() == 1) { + TF_RETURN_IF_ERROR(fc.RemoveTrivialSwitch(switch_cluster)); } else { - TF_RETURN_IF_ERROR(fc.ConvertMergeToXlaIf(merge_cluster)); + TF_RETURN_IF_ERROR(fc.ConvertCorrespondingMergeToXlaIf(switch_cluster)); } - // Contract newly Merge free merge_cluster with incoming nodes without + // Contract newly Switch free switch_cluster with outgoing nodes without // Switch or Merge nodes. - std::vector in_nodes(merge_cluster->in_nodes.begin(), - merge_cluster->in_nodes.end()); - for (auto in : in_nodes) { - if (in->merge_nodes.empty() && in->switch_nodes.empty()) { - fc.ContractEdge(in, merge_cluster); + for (auto& nodes : {switch_cluster->out_nodes, switch_cluster->in_nodes}) { + std::vector copy_nodes(nodes.begin(), nodes.end()); + for (auto* node : copy_nodes) { + if (node->merge_nodes.empty() && node->switch_nodes.empty()) { + fc.ContractEdge(node, switch_cluster); + } } } + + VLOG(3) << "Graph with clusters: " + << DebugString(*fc.graph_, &fc.clusters_); + VLOG(3) << "ClusteredGraph: " << DebugString(fc.clustered_graph_); } if (!fc.switch_nodes_.empty()) { diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..82b3b46a2f1e97001d1e0c6b993ec243170bc7d8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -0,0 +1,242 @@ +**Supported operators for device: XLA_CPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={complex64,double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`RandomStandardNormal` | `dtype={float}` +`RandomUniform` | `T={int32,int64}`
`dtype={double,float}` +`RandomUniformInt` | `T={int32,int64}`
`Tout={int32,int64}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`TruncatedNormal` | `T={int32,int64}`
`dtype={double,float}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_CPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..d4b7621ad2858fe17e93d292dd807e4f7c1c336b --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -0,0 +1,238 @@ +**Supported operators for device: XLA_GPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={complex64,double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`Fill` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_GPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 2b43e313eb42c288b891f97c0b6cd3cacdc77711..3e24cf042e17ad4e212d82ac4f24fec06a6c780f 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -18,6 +18,8 @@ tf_kernel_library( "bias_ops.cc", "binary_ops.cc", "cast_op.cc", + "categorical_op.cc", + "cholesky_op.cc", "concat_op.cc", "const_op.cc", "conv_ops.cc", @@ -33,6 +35,7 @@ tf_kernel_library( "gather_op.cc", "gather_op_helpers.h", "identity_op.cc", + "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", "lrn_ops.cc", @@ -52,17 +55,20 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "scan_ops.cc", "segment_reduction_ops.cc", "select_op.cc", "sendrecv_ops.cc", "sequence_ops.cc", "shape_op.cc", + "shape_util.cc", "slice_op.cc", "softmax_op.cc", "spacetobatch_op.cc", "spacetodepth_op.cc", "split_op.cc", "stack_ops.cc", + "stateless_random_ops.cc", "strided_slice_op.cc", "tensor_array_ops.cc", "tile_ops.cc", @@ -75,12 +81,17 @@ tf_kernel_library( hdrs = [ "gather_op.h", "index_ops.h", + "shape_util.h", ], deps = [ ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/lib:batch_dot", + "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", + "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", @@ -89,8 +100,11 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:framework", + "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", + "//tensorflow/core:linalg_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stateless_random_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", "//tensorflow/core/kernels:constant_op", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 73ccc151c1d6bdf70105badd962903297f090abe..a015b8e0e8949f8aaa03a78b0f88b7ea8d6aaa1c 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -13,11 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// XLA-specific BatchMatMul Op. -// The current implementation simply unrolls the computation along the batch -// dimension. -// TODO(dominikg,phawkins): Use a real batched matmul instead of unrolling. - +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -32,110 +28,10 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - const TensorShape x_shape = ctx->InputShape(0); - const TensorShape y_shape = ctx->InputShape(1); - - // Check that both tensors have the same number of dimensions. There must be - // at least two (the batch dimensions can be empty). - OP_REQUIRES(ctx, x_shape.dims() == y_shape.dims(), - errors::InvalidArgument("In[0] and In[1] has different ndims: ", - x_shape.DebugString(), " vs. ", - y_shape.DebugString())); - const int ndims = x_shape.dims(); - OP_REQUIRES( - ctx, ndims >= 2, - errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims)); - - // The batch dimensions must be equal and the matrix dimensions must be - // valid. - std::vector dimensions; - int batch_count = 1; - for (int i = 0; i < ndims - 2; ++i) { - OP_REQUIRES( - ctx, x_shape.dim_size(i) == y_shape.dim_size(i), - errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", i, - ") must be the same: ", x_shape.DebugString(), - " vs ", y_shape.DebugString())); - dimensions.push_back(x_shape.dim_size(i)); - batch_count *= x_shape.dim_size(i); - } - - int x_inner_dim = adj_x_ ? (ndims - 2) : (ndims - 1); - int y_inner_dim = adj_y_ ? (ndims - 1) : (ndims - 2); - OP_REQUIRES( - ctx, x_shape.dim_size(x_inner_dim) == y_shape.dim_size(y_inner_dim), - errors::InvalidArgument( - "In[0] mismatch In[1] shape: ", x_shape.dim_size(x_inner_dim), - " vs. ", y_shape.dim_size(y_inner_dim), ": ", x_shape.DebugString(), - " ", y_shape.DebugString(), " ", adj_x_, " ", adj_y_)); - - int x_outer_dim = adj_x_ ? (ndims - 1) : (ndims - 2); - int y_outer_dim = adj_y_ ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape.dim_size(x_outer_dim)); - dimensions.push_back(y_shape.dim_size(y_outer_dim)); - - xla::ComputationBuilder* builder = ctx->builder(); - - xla::ComputationDataHandle x_handle = ctx->Input(0); - if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) { - x_handle = builder->Conj(x_handle); - } - xla::ComputationDataHandle y_handle = ctx->Input(1); - if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) { - y_handle = builder->Conj(y_handle); - } - - // Reshape input tensors into 3D tensors by flattening the batch - // dimensions. This makes it easier to unroll the batch dimension. - auto x_flat = - builder->Reshape(x_handle, {batch_count, x_shape.dim_size(ndims - 2), - x_shape.dim_size(ndims - 1)}); - auto y_flat = - builder->Reshape(y_handle, {batch_count, y_shape.dim_size(ndims - 2), - y_shape.dim_size(ndims - 1)}); - - // Slice batches into individual matrices and multiply them. - std::vector out_slices; - for (int i = 0; i < batch_count; ++i) { - // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder->Slice( - x_flat, {i, 0, 0}, - {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}, - {1, 1, 1}); - x_slice = builder->Reshape( - x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)}); - auto y_slice = builder->Slice( - y_flat, {i, 0, 0}, - {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}, - {1, 1, 1}); - y_slice = builder->Reshape( - y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)}); - - // Transpose if needed. - auto lhs = adj_x_ ? builder->Transpose(x_slice, {1, 0}) : x_slice; - auto rhs = adj_y_ ? builder->Transpose(y_slice, {1, 0}) : y_slice; - - // Multiply matrices and add an outer singleton dimension to the output - // so we can concatenate along the flattened batch dimension later. - auto out = builder->Dot(lhs, rhs); - out = builder->Reshape(out, - {1, dimensions[ndims - 2], dimensions[ndims - 1]}); - out_slices.push_back(out); - } - - // Concatenate output slices and reshape to original number of dimensions. - xla::ComputationDataHandle data; - if (out_slices.empty()) { - // It is illegal to pass an empty list to ConcatInDim. - // The batch count is empty, so both inputs must have zero elements. - // Arbitrarily use the left input as the argument to Reshape(). - data = x_handle; - } else { - data = builder->ConcatInDim(out_slices, 0); - } - data = builder->Reshape(data, dimensions); - - ctx->SetOutput(0, data); + auto result = + BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 248e9d111e556dcdd75581aa6562a66fc8b57063..a249b1869f547f8e5aa725f9f5cf391b10429928 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // XLA implementation of BatchNorm operations. -#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,43 +26,63 @@ namespace { class FusedBatchNormOp : public XlaOpKernel { public: explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format); - } + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); + + xla::ComputationBuilder* builder = ctx->builder(); + + xla::ComputationDataHandle input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. As a workaround, cast everything to the statistics type (which + // may be more precise than the input type). + input = builder->ConvertElementType(input, scale_type); + if (is_training_) { - xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining( - ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, - feature_index_); + xla::ComputationDataHandle output = builder->BatchNormTraining( + input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); - } + ctx->SetOutput(0, builder->ConvertElementType( + builder->GetTupleElement(output, 0), input_type)); + ctx->SetOutput(1, builder->GetTupleElement(output, 1)); + ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + ctx->SetOutput(3, builder->GetTupleElement(output, 1)); + ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = ctx->builder()->BatchNormInference( - ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3), - ctx->Input(4), epsilon_, feature_index_); - ctx->SetOutput(0, output); + xla::ComputationDataHandle output = builder->BatchNormInference( + input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), + epsilon_, feature_index); + ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -73,55 +93,113 @@ class FusedBatchNormOp : public XlaOpKernel { private: float epsilon_; - int64 feature_index_; + TensorFormat data_format_; bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); +REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); class FusedBatchNormGradOp : public XlaOpKernel { public: explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - bool is_training; - OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training)); - CHECK(is_training) << "FusedBatchNormGradOp with is_training=False cannot " - "be used with XLA for now!"; - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(4, tensor_format); - } + OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { - auto grad_output = ctx->Input(0); - auto activation = ctx->Input(1); + xla::ComputationBuilder* b = ctx->builder(); + + auto grad_backprop = ctx->Input(0); + auto activations = ctx->Input(1); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); - xla::ComputationDataHandle output = ctx->builder()->BatchNormGrad( - activation, scale, mean, var, grad_output, epsilon_, feature_index_); - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); + TensorShape input_shape = ctx->InputShape(0); + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + DataType input_dtype = ctx->input_type(0); + DataType scale_dtype = ctx->input_type(2); + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. For now, cast everything to the statistics type (which + // may be more precise than the input type). + grad_backprop = b->ConvertElementType(grad_backprop, scale_type); + activations = b->ConvertElementType(activations, scale_type); + + xla::ComputationDataHandle x_backprop; + xla::ComputationDataHandle scale_backprop; + xla::ComputationDataHandle offset_backprop; + if (is_training_) { + xla::ComputationDataHandle output = + b->BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); + + x_backprop = b->GetTupleElement(output, 0); + scale_backprop = b->GetTupleElement(output, 1); + offset_backprop = b->GetTupleElement(output, 2); + } else { + // Reduce over all dimensions except the feature dim. + std::vector reduction_dims(input_shape.dims() - 1); + std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, + 0); + std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), + feature_index + 1); + // offset_backprop = sum(y_backprop) + // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + + // epsilon)) + // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) + offset_backprop = + b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), + *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + + // scratch1 = rsqrt(pop_var + epsilon) + auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); + auto scratch1 = + b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); + + // scratch2 = sum(y_backprop * (x - mean)) + auto scratch2 = b->Reduce( + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), + XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), + reduction_dims); + + x_backprop = + b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); + scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + + ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(1, scale_backprop); + ctx->SetOutput(2, offset_backprop); + ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); + ctx->SetConstantOutput(4, Tensor(scale_dtype, {})); } private: + TensorFormat data_format_; float epsilon_; - int64 feature_index_; + bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); +REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 1de91924326464338352b1ac9edf77141f25ad35..2436a6074a11ad66387b232dd1c5aa135875bfc3 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" namespace tensorflow { namespace { @@ -75,7 +76,7 @@ static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, auto abs_y = b->Abs(y); auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); - if (dtype == DT_FLOAT || dtype == DT_DOUBLE) { + if (DataTypeIsFloating(dtype)) { result = b->Floor(result); } return result; diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..592f3ecc3ce2abf33ddffe8b0e59c4e12e73e956 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA implementations of Categorical op. + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class CategoricalOp : public XlaOpKernel { + public: + explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // Get the logits + const xla::ComputationDataHandle& logits = ctx->Input(0); + TensorShape logits_shape = ctx->InputShape(0); + int64 num_samples; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples)); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), + errors::InvalidArgument("logits should be a matrix, got shape ", + logits_shape.DebugString())); + OP_REQUIRES(ctx, num_samples >= 0, + errors::InvalidArgument( + "num_samples should be nonnegative, got ", num_samples)); + + for (int i = 0; i < 2; i++) { + const int64 dim = logits_shape.dim_size(i); + OP_REQUIRES( + ctx, static_cast(dim) == dim, + errors::InvalidArgument("logits.shape = ", logits_shape.DebugString(), + " too large for int")); + } + + const int64 batch_size = logits_shape.dim_size(0); + const int64 num_classes = logits_shape.dim_size(1); + + xla::ComputationBuilder* builder = ctx->builder(); + + std::array uniform_shape_array = { + {batch_size, num_samples, num_classes}}; + xla::PrimitiveType uniform_xla_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); + xla::Shape uniform_shape = + xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); + auto uniforms = builder->RngUniform( + XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); + + // Use Gumbel softmax trick to generate categorical samples. + // See: + // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ + // TODO(b/68769470): Switch to using a cumulative sum approach. + auto softmax_entries = + builder->Sub(logits, builder->Log(builder->Neg(builder->Log(uniforms))), + /*broadcast_dimensions=*/{0, 2}); + + TensorShape softmax_shape(uniform_shape_array); + xla::ComputationDataHandle argmax; + OP_REQUIRES_OK( + ctx, + XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, + input_type(0), output_type(0), /*axis=*/2, &argmax)); + + ctx->SetOutput(0, argmax); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp); +}; + +// TODO(b/68769717): Rename this sampler to Categorical. +REGISTER_XLA_OP(Name("Multinomial"), CategoricalOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc similarity index 50% rename from tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h rename to tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index c178927f5d5411e30bee2470b8b544ff76c28396..87d858f763560be454c162e0cf40307c68217663 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_namespace_compat.h +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -13,20 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ - -// This file is a transitional place-holder until gRPC versions consistently -// use namespace grpc::internal for library-internal structures - -namespace grpc { -// ensure internal namespace exists -namespace internal { -// bring in contents of external namespace -using namespace ::grpc; -} // namespace internal -// bring in contents of internal namespace -using namespace internal; -} // namespace grpc - -#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_NAMESPACE_COMPAT_H_ +#include "tensorflow/compiler/tf2xla/lib/cholesky.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class CholeskyOp : public XlaOpKernel { + public: + explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + void Compile(XlaOpKernelContext* ctx) override { + auto result = Cholesky(ctx->builder(), ctx->Input(0)); + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + ctx->SetOutput(0, result.ValueOrDie()); + } +}; + +REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 9833323d851e00e7ca76d0b39cd2b216748a17fa..8f78b4c8f90cf00d5fa9ba71a78bb1c0fe280dc6 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -40,6 +40,11 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); + if (proto_.dtype() == DT_STRING) { + LOG(WARNING) << "Not computing Const of type DT_STRING"; + ctx->SetInvalidOutput(0); + return; + } xla::ComputationBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 885f716afafca7ba23770e38f6693eed1ba50982..aaddbe811c6fbf6da296640eb5a75e82b2fedcfa 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -46,72 +46,130 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( return expanded_shape; } +// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. +xla::ComputationDataHandle CreateExpandedZero( + const TensorShape& filter_shape, DataType dtype, + xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + return builder->Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tesnsor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B it by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 amd return the result at +// the beginning of the comment. +xla::ComputationDataHandle CreateExpandedFilterMask( + const TensorShape& filter_shape, xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::ComputationDataHandle input_feature_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, + &input_feature_iota)); + xla::ComputationDataHandle expanded_feature_iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + input_feature * depthwise_multiplier, + &expanded_feature_iota)); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + builder->Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = builder->Broadcast( + expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); +} + // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter, xla::ComputationBuilder* builder) { - // Filter has shape [H, W, ..., M, N] - // Dilate to [H, W, ..., M*M, N] using M inter-element padding, and then - // reshape to [H, W, ..., M, M*N]. - int num_spatial_dims = filter_shape.dims() - 2; - const int64 in_depth = filter_shape.dim_size(num_spatial_dims); - xla::PaddingConfig padding = xla::MakeNoPaddingConfig(filter_shape.dims()); - padding.mutable_dimensions(num_spatial_dims)->set_interior_padding(in_depth); - auto dilated_filter = - builder->Pad(filter, XlaHelpers::Zero(builder, dtype), padding); - + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Reshape(dilated_filter, expanded_filter_shape.dim_sizes()); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 2, 1); + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 1, + depthwise_multiplier * input_feature); + auto implicit_broadcast_filter = + builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + + // Broadcast the filter to [H, W, ..., M, M*N]. + auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); + auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + + // If the filter mask is set, choose the broadcasted filter, othwerwise, + // choose zero. + return builder->Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - const TensorShape& filter_shape, DataType dtype, + XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter_backprop, xla::ComputationBuilder* builder) { - int num_spatial_dims = filter_shape.dims() - 2; - - // Reshape to [H, W, ..., M*M, N] - TensorShape shape = filter_shape; - int64 in_depth = filter_shape.dim_size(num_spatial_dims); - shape.set_dim(num_spatial_dims, in_depth * in_depth); - auto reshaped = builder->Reshape(filter_backprop, shape.dim_sizes()); - - std::vector zeros(filter_shape.dims()); - std::vector strides(filter_shape.dims(), 1LL); - strides[num_spatial_dims] = in_depth + 1; - return builder->Slice(reshaped, zeros, shape.dim_sizes(), strides); - - // Alternate implementation for backends without strided Slice() support. - // TODO(phawkins): Remove when all backends support strided slice. - // // Pad [..., M * (M + 1), N] - // xla::PaddingConfig config = - // xla::MakeNoPaddingConfig(filter_shape.dims()); - // config.mutable_dimensions(num_spatial_dims) - // ->set_edge_padding_high(in_depth); - // auto zero = XlaHelpers::Zero(builder, dtype); - // auto padded = builder->Pad(reshaped, zero, config); - // - // // Reshape to [..., M, M + 1, N] - // shape = filter_shape; - // shape.set_dim(num_spatial_dims, in_depth); - // shape.set_dim(num_spatial_dims + 1, in_depth + 1); - // int64 out_depth = filter_shape.dim_size(num_spatial_dims + 1); - // shape.AddDim(out_depth); - // reshaped = builder->Reshape(padded, shape.dim_sizes()); - // - // // Slice to [..., M, 1, N] - // std::vector zeros(shape.dims()); - // std::vector strides(shape.dims(), 1LL); - // shape.set_dim(num_spatial_dims + 1, 1); - // auto sliced = builder->Slice(reshaped, zeros, shape.dim_sizes(), - // strides); - // - // // Reshape to [..., M, N] - // return builder->Reshape(sliced, filter_shape.dim_sizes()); + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + auto masked_expanded_filter = builder->Select( + CreateExpandedFilterMask(filter_shape, builder), filter_backprop, + CreateExpandedZero(filter_shape, dtype, builder)); + return builder->Reshape( + builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), + filter_shape.dim_sizes()); } class ConvOp : public XlaOpKernel { @@ -121,6 +179,7 @@ class ConvOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); @@ -144,6 +203,23 @@ class ConvOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + const TensorShape input_shape = ctx->InputShape(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, ..., in_depth, out_depth] @@ -184,10 +260,11 @@ class ConvOp : public XlaOpKernel { dims.set_input_feature_dimension(feature_dim); dims.set_output_feature_dimension(feature_dim); for (int i = 0; i < num_spatial_dims_; ++i) { - int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dims.add_spatial_dimensions(input_dim); + const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dims.add_input_spatial_dimensions(dim); dims.add_kernel_spatial_dimensions(i); - window_strides.push_back(strides_.at(input_dim)); + dims.add_output_spatial_dimensions(dim); + window_strides.push_back(strides_.at(dim)); } dims.set_kernel_input_feature_dimension(num_spatial_dims_); dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); @@ -203,6 +280,7 @@ class ConvOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -240,6 +318,7 @@ class ConvBackpropInputOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -262,6 +341,23 @@ class ConvBackpropInputOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -302,9 +398,10 @@ class ConvBackpropInputOp : public XlaOpKernel { std::vector lhs_dilation(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); for (int i = 0; i < num_spatial_dims_; ++i) { - dnums.add_spatial_dimensions( - GetTensorSpatialDimIndex(num_dims(), data_format_, i)); + int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(i); + dnums.add_output_spatial_dimensions(dim); kernel_spatial_dims[i] = i; padding[i] = {dims.spatial_dims[i].pad_before, @@ -334,6 +431,7 @@ class ConvBackpropInputOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -371,6 +469,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -390,6 +489,23 @@ class ConvBackpropFilterOp : public XlaOpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES( + ctx, dilations_[input_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "dilations in the ", + i, "th spatial dimension.")); + } + const TensorShape activations_shape = ctx->InputShape(0); TensorShape filter_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); @@ -424,9 +540,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Swap n_dim and c_dim in the activations. dnums.set_input_batch_dimension(c_dim); - dnums.set_output_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); - dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] @@ -438,9 +552,16 @@ class ConvBackpropFilterOp : public XlaOpKernel { std::vector rhs_dilation(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. for (int i = 0; i < num_spatial_dims_; ++i) { - int dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); - dnums.add_spatial_dimensions(dim); + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(num_spatial_dims_); + dnums.set_output_feature_dimension(num_spatial_dims_ + 1); + + for (int i = 0; i < num_spatial_dims_; ++i) { + int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); // We will also need to pad the input with zeros such that after the @@ -498,31 +619,17 @@ class ConvBackpropFilterOp : public XlaOpKernel { /*window_strides=*/ones, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums); - // The layout of filter_backprop will match the layout of - // padded_activations - // and so will have layout: [out_feature, h, w, ..., in_feature] - // Tensorflow filter shape is [ H, W, ..., inC, outC ], so we transpose the - // output. - std::vector transpose_dims; - transpose_dims.reserve(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - transpose_dims.push_back(dnums.spatial_dimensions(i)); - } - transpose_dims.push_back(c_dim); - transpose_dims.push_back(n_dim); - xla::ComputationDataHandle filter_backprop_reshaped = - b->Transpose(filter_backprop, transpose_dims); - if (depthwise_) { - filter_backprop_reshaped = ContractFilterForDepthwiseBackprop( - filter_shape, ctx->input_type(0), filter_backprop_reshaped, b); + filter_backprop = ContractFilterForDepthwiseBackprop( + ctx, filter_shape, ctx->input_type(0), filter_backprop, b); } - ctx->SetOutput(0, filter_backprop_reshaped); + ctx->SetOutput(0, filter_backprop); } protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index a4ea65ea89e348cb77412efb0c5c0fcb1a9f33f3..96d7809f7995634b6bc31ab801b93526d9da7e6f 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class DepthToSpaceOp : public XlaOpKernel { public: explicit DepthToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,18 +42,79 @@ class DepthToSpaceOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got: ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if (data_format_ == FORMAT_NHWC) { + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i]); + } + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i + 1); + transpose_order.push_back(i + 1 + num_spatial_dims); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] * block_size_); + } + output_shape.push_back(input_shape[feature_dim] / block_elems); + } else { + // NCHW format. + reshaped_shape.push_back(input_shape[0]); + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i]); + } + + transpose_order.push_back(0); + transpose_order.push_back(1 + num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(2 + num_spatial_dims + i); + transpose_order.push_back(1 + i); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] * block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, @@ -51,14 +123,14 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // block_size_, // depth / (block_size_ * block_size_)] - OP_REQUIRES(ctx, input_shape[3] % (block_size_ * block_size_) == 0, + OP_REQUIRES(ctx, + input_shape[feature_dim] % (block_size_ * block_size_) == 0, errors::InvalidArgument( "Input depth dimension (", input_shape[3], ") is not divisible by square of the block size (", block_size_, ")")); - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1], input_shape[2], block_size_, - block_size_, input_shape[3] / (block_size_ * block_size_)}); + + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -70,7 +142,7 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // depth / (block_size_ * block_size_)] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -80,15 +152,14 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] * block_size_, - input_shape[2] * block_size_, - input_shape[3] / (block_size_ * block_size_)}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("DepthToSpace"), DepthToSpaceOp); diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ec5017f6ab96bd3fc273a746b77fbb7e74fd9f35..765ea922a532a085a552192348ab360c4c30ff0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +24,62 @@ limitations under the License. namespace tensorflow { namespace { +// Create a diagonal / batch diagonal matrix with 'input' on the diagonal. +xla::StatusOr CreateDiagonal( + const xla::ComputationDataHandle& input, int64 last_dim_size, + tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, + xla::ComputationBuilder* builder) { + // Create two matrices that have the following forms, and compare them: + // + // [[0, 0, 0, 0] [[0, 1, 2, 3] + // [1, 1, 1, 1] [0, 1, 2, 3] + // [2, 2, 2, 2] [0, 1, 2, 3] + // [3, 3, 3, 3]] [0, 1, 2, 3]] + // + // This produces a predicate matrix of the right size, with "true" on the + // diagonal. + xla::ComputationDataHandle iota; + TF_RETURN_IF_ERROR( + XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); + xla::ComputationDataHandle iota_broadcast = + builder->Broadcast(iota, {last_dim_size}); + xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); + + // If this is a batched diagonal, broadcast the mask across the other + // dimensions. + if (!other_dims.empty()) { + mask = builder->Broadcast(mask, other_dims); + } + + // Broadcast the input, and then use the mask computed above to select the + // diagonal: + // e.g, in 2D: + // [[t, f, f] [[1, 1, 1] [[0, 0, 0] [[1, 0, 0] + // select( [f, t, f] , [4, 4, 4] , [0, 0, 0] ) = [0, 4, 0] + // [f, f, t]] [9, 9, 9]] [0, 0, 0]] [0, 0, 9]] + // + // Broadcasting the input is less-than-trivial, since we need to broadcast + // into a "middle" dimension. We can do this with a reshape + implicit + // broadcast. + // TODO(b/30112114): Replace with in-dim broadcast when those are supported. + std::vector broadcast_dims(other_dims.begin(), other_dims.end()); + broadcast_dims.push_back(1LL); + broadcast_dims.push_back(last_dim_size); + xla::ComputationDataHandle input_broadcast = + builder->Reshape(input, broadcast_dims); + + broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; + xla::PrimitiveType element_type; + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); + auto broadcast_shape = + xla::ShapeUtil::MakeShape(element_type, broadcast_dims); + xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); + + input_broadcast = builder->Add(input_broadcast, zeros); + return builder->Select(mask, input_broadcast, zeros); +} + class DiagOp : public XlaOpKernel { public: explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -29,6 +87,8 @@ class DiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -36,7 +96,7 @@ class DiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::ComputationDataHandle input = ctx->Input(0); // Picture: // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] @@ -46,13 +106,13 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); + input = builder->Reshape(input, {size}); - // Adds inter-element padding of 'size'. - xla::PaddingConfig config; - auto* dim = config.add_dimensions(); - dim->set_interior_padding(size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + // Create an R2 with the R1 diagonal. + auto diag_or_status = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); @@ -141,6 +201,8 @@ class MatrixDiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -152,17 +214,13 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); + tensorflow::gtl::ArraySlice other_dims(dims); + other_dims.pop_back(); - // Adds inter-element padding of 'last_dim_size' to the last dimension. - xla::PaddingConfig config = xla::MakeNoPaddingConfig(dims.size()); - auto* dim = config.mutable_dimensions(last_dim); - dim->set_interior_padding(last_dim_size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); - - // Reshapes to the final shape. - dims.push_back(last_dim_size); - diag = builder->Reshape(diag, dims); - + auto diag_or_status = + CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + diag = diag_or_status.ValueOrDie(); ctx->SetOutput(0, diag); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..d91ebb500b4479dbb3c8e2ea7719bc79dc24ba4f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -0,0 +1,367 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +// We implement bilinear interpolation by upsampling followed by convolution. +// The basic idea is as follows. To scale from NxN to RxR: +// +// 1. S := (N - 1) / gcd(N-1, R-1) +// 2. k := (R - 1) / gcd(N-1, R-1) +// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) +// +// For example, to Scale from 7x7 -> 15x15: +// +// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 +// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 +// 3. Convolution(7x7, stride=3, lhs_dilation=3, padding=2) +// +// +// The 7x7 -> 15x15 case is much too large to write out in full as an +// example. The smallest interesting example is 3x3 -> 4x4. +// +// S := 2 +// k := 3 +// +// 00 03 06 00 00 00 00 00 00 00 00 00 00 00 00 02 04 06 +// 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12 +// 18 21 24 00 00 00 00 00 03 00 00 06 00 00 12 14 16 18 +// 00 00 00 00 00 00 00 00 00 00 00 18 20 22 24 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 09 00 00 12 00 00 15 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 18 00 00 21 00 00 24 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// +// with the following convolutional kernel, with stride [2, 2]: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 + +// Computes the size of the convolutional kernel and stride to use when resizing +// from in_size to out_size. +struct ResizeConvolutionDims { + // Size of the kernel to use. + std::vector kernel_size; + + // Stride of the convolution to use. + std::vector stride; +}; +ResizeConvolutionDims ComputeResizeConvolutionParameters( + gtl::ArraySlice in_size, gtl::ArraySlice out_size) { + CHECK_EQ(in_size.size(), out_size.size()); + int num_spatial_dims = in_size.size(); + ResizeConvolutionDims dims; + dims.kernel_size.resize(num_spatial_dims); + dims.stride.resize(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1) { + // We must handle input size 1 specially because XLA convolution does + // not allow stride 0. + dims.stride[i] = dims.kernel_size[i] = 1; + } else if (out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + dims.stride[i] = dims.kernel_size[i] = 1; + } else { + int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1), + static_cast(out_size[i] - 1)); + dims.stride[i] = (in_size[i] - 1) / gcd; + dims.kernel_size[i] = (out_size[i] - 1) / gcd; + } + } + return dims; +} + +xla::ComputationDataHandle MakeBilinearResizeKernel( + xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size, + int64 channels) { + // Form a 2D convolution kernel like: + // 1 2 3 2 1 + // 2 4 6 4 2 + // 1/9 * 3 6 9 6 3 + // 2 4 6 4 2 + // 1 2 3 2 1 + // by multiplying two 1D kernels of the form: + // 1/3 * [1 2 3 2 1] + auto make_1d_kernel = [](int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = i + 1; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; + }; + + // Form a block diagonal kernel where each channel interacts only with itself. + xla::Array4D diag(1, 1, channels, channels, 0.0f); + for (int i = 0; i < channels; ++i) { + diag(0, 0, i, i) = 1.0f / (kernel_size[0] * kernel_size[1]); + } + return builder->Mul( + builder->ConstantR1(make_1d_kernel(kernel_size[0])), + builder->Mul(builder->ConstantR1(make_1d_kernel(kernel_size[1])), + builder->ConstantR4FromArray4D(diag), + /*broadcast_dimensions=*/{1}), + /*broadcast_dimensions=*/{0}); +} + +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented( + "ResizeBilinear with align_corners=False is not yet implemented")); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + const std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + std::vector out_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); + OP_REQUIRES(ctx, out_size.size() == 2, + errors::InvalidArgument("output size must be length 2, got ", + out_size.size())); + OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, + errors::InvalidArgument("output size must be positive, got [", + out_size[0], ",", out_size[1], "]")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle input = ctx->Input(0); + + // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in + // dimension i. + std::vector slice_size = in_size; + bool slice_input = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + slice_input = true; + slice_size[i] = 1; + } + } + if (slice_input) { + input = b->Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); + } + + // Output is always type float. + input = b->ConvertElementType(input, xla::F32); + + // Picture for a 1x3 to 1x4 resize: + // stride = 2, kernel size = 3 + // Input: + // 3 6 9 + // Input with dilation and padding: + // 0 0 3 0 0 6 0 0 9 0 0 + // Convolution kernel: + // 1/3 * [1 2 3 2 1] + // Output: + // 3 5 7 9 + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dnums.add_input_spatial_dimensions(1 + i); + dnums.add_output_spatial_dimensions(1 + i); + dnums.add_kernel_spatial_dimensions(i); + } + dnums.set_kernel_input_feature_dimension(num_spatial_dims); + dnums.set_kernel_output_feature_dimension(num_spatial_dims + 1); + + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, out_size); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(b, dims.kernel_size, channels); + xla::ComputationDataHandle output = b->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dnums); + + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + output = b->Add(output, b->ConstantR1(out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); + } + } + + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinear"), ResizeBilinearOp); + +class ResizeBilinearGradOp : public XlaOpKernel { + public: + explicit ResizeBilinearGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented("ResizeBilinearGrad with align_corners=False is " + "not yet implemented")); + + DataType output_dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + const std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + TensorShape grad_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, grad_shape.dims() == 4, + errors::InvalidArgument("gradient must be 4-dimensional", + grad_shape.DebugString())); + const int64 grad_batch = grad_shape.dim_size(0); + const std::vector grad_size = {grad_shape.dim_size(1), + grad_shape.dim_size(2)}; + const int64 grad_channels = grad_shape.dim_size(3); + OP_REQUIRES(ctx, batch == grad_batch, + errors::InvalidArgument( + "activations and gradients must have the same batch size (", + batch, " vs. ", grad_batch, ")")); + OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0, + errors::InvalidArgument("gradient size must be positive, got [", + grad_size[0], ",", grad_size[1], "]")); + OP_REQUIRES( + ctx, channels == grad_channels, + errors::InvalidArgument( + "activations and gradients must have the same number of channels (", + channels, " vs. ", grad_channels, ")")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle grad = ctx->Input(0); + + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, grad_size); + + // To form the backward convolution, we keep the kernel unchanged (it is + // already symmetric) and swap the roles of strides and LHS dilation. + xla::ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dnums.add_input_spatial_dimensions(1 + i); + dnums.add_output_spatial_dimensions(1 + i); + dnums.add_kernel_spatial_dimensions(i); + } + dnums.set_kernel_input_feature_dimension(num_spatial_dims); + dnums.set_kernel_output_feature_dimension(num_spatial_dims + 1); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(b, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = b->Add(kernel, b->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } + + xla::ComputationDataHandle output = b->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dnums); + + // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. + // Opposite of the slice performed by the forward op. + xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4); + bool pad_output = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && grad_size[i] == 1) { + pad_output = true; + padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - + 1); + } + } + if (pad_output) { + output = b->Pad(output, b->ConstantR0(0.0f), padding); + } + + output = b->ConvertElementType(output, output_type_); + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; + xla::PrimitiveType output_type_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index b8769b3ea2be0a791d9c3e5e7acd8b6184442af2..e0dc1870f2a4934c35163f0cc10196e8fcbed9be 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -60,54 +60,20 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { input_shape.DebugString())); DataType index_type = output_type(0); - xla::PrimitiveType xla_input_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(0), &xla_input_type)); - xla::PrimitiveType xla_index_type; - OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(index_type, &xla_index_type)); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); - xla::ComputationDataHandle init_value; - const xla::Computation* reducer; + xla::ComputationDataHandle output; if (is_min_) { - init_value = XlaHelpers::MaxValue(b, input_type(0)); - reducer = ctx->GetOrCreateMin(input_type(0)); + OP_REQUIRES_OK(ctx, + XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0), + index_type, axis, &output)); } else { - init_value = XlaHelpers::MinValue(b, input_type(0)); - reducer = ctx->GetOrCreateMax(input_type(0)); + OP_REQUIRES_OK(ctx, + XlaHelpers::ArgMax(b, ctx, input, input_shape, input_type(0), + index_type, axis, &output)); } - xla::ComputationDataHandle input_max = - b->Reduce(input, init_value, *reducer, /*dimensions_to_reduce=*/{axis}); - std::vector broadcast_dims(input_dims - 1); - std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); - std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); - // Compute a mask that has 1s for elements equal to the maximum. - xla::ComputationDataHandle partial_mask = b->ConvertElementType( - b->Eq(input, input_max, broadcast_dims), xla_index_type); - - // In order to make identity elements for a bitwise And, we: - // Left shift the 1 to the leftmost bit, yielding 0x10...0 - // Arithmetic right shift the 1 back to the rightmost bit, yielding 0xFF...F - int32 bits_in_type = - xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_index_type) * 8 - 1; - xla::ComputationDataHandle shift_amount = - XlaHelpers::IntegerLiteral(b, index_type, bits_in_type); - xla::ComputationDataHandle full_mask = b->ShiftRightArithmetic( - b->ShiftLeft(partial_mask, shift_amount), shift_amount); - - // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its index. - xla::ComputationDataHandle iota; - OP_REQUIRES_OK(ctx, XlaHelpers::Iota(b, index_type, axis_size, &iota)); - xla::ComputationDataHandle product = - b->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); - - // If there are multiple maximum elements, choose the one with the highest - // index. - xla::ComputationDataHandle output = - b->Reduce(product, XlaHelpers::MinValue(b, index_type), - *ctx->GetOrCreateMax(index_type), - /*dimensions_to_reduce=*/{axis}); ctx->SetOutput(0, output); } diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index fcef497e5845d9080bc83b54e92dcf2fdecf5f12..644abd5905c6ce5a8f61792a1986560bab891040 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,8 +23,8 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; class MatMulOp : public XlaOpKernel { public: @@ -85,10 +85,7 @@ class SparseMatMulOp : public MatMulOp { ~SparseMatMulOp() override = default; }; -REGISTER_XLA_OP(Name("SparseMatMul") - .TypeConstraint("Ta", kFloatTypes) - .TypeConstraint("Tb", kFloatTypes), - SparseMatMulOp); +REGISTER_XLA_OP(Name("SparseMatMul"), SparseMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 462267d1504f16a5fc1f34f5804649416699005a..c283e3b02c2676785952e3e17bffa671b0dabc1e 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -60,7 +60,13 @@ class RetvalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - tc.AddRetval(index_, dtype_, input); + // The core from which a return value is returned depends on the core + // assignment of the input to the retval .Since we can't change the core + // assignment of as this point, create a tuple/get-tuple-element + // combination so that the core will be set on them. + auto tuple_elem = + ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); + tc.AddRetval(index_, dtype_, tuple_elem); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..650f8c7dc8be0cb08997ec641ca3f82352166fdd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// TODO(phawkins): implement double-sized windowed reductions in XLA and remove +// the type constraint. +constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; + +class ScanOp : public XlaOpKernel { + public: + ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape tensor_axis_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape), + errors::InvalidArgument("ScanOp: axis must be a scalar, not ", + tensor_axis_shape.DebugString())); + + int64 axis; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis)); + if (axis < 0) { + axis += input_shape.dims(); + } + OP_REQUIRES( + ctx, FastBoundsCheck(axis, input_shape.dims()), + errors::InvalidArgument("ScanOp: Expected scan axis in the range [", + -input_shape.dims(), ", ", input_shape.dims(), + "), but got ", axis)); + + DataType dtype = ctx->input_type(0); + + if (input_shape.num_elements() == 0) { + // Exit early if there is nothing to compute. + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + xla::ComputationBuilder* builder = ctx->builder(); + + std::vector window_strides(input_shape.dims(), 1); + std::vector window_dims(input_shape.dims(), 1); + window_dims[axis] = input_shape.dim_size(axis); + + std::vector> padding(input_shape.dims(), {0, 0}); + padding[axis].first = input_shape.dim_size(axis) - 1; + // In exclusive mode, add an extra padding element so there is a complete + // window of padding before the data starts. + if (exclusive_) { + ++padding[axis].first; + } + if (reverse_) { + std::swap(padding[axis].first, padding[axis].second); + } + + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle init; + const xla::Computation* reducer; + if (sum_) { + init = XlaHelpers::Zero(builder, dtype); + reducer = ctx->GetOrCreateAdd(dtype); + } else { + init = XlaHelpers::One(builder, dtype); + reducer = ctx->GetOrCreateMul(dtype); + } + auto output = builder->ReduceWindowWithGeneralPadding( + ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + + // In exclusive mode, we have computed an extra element containing the sum + // of all the input elements. Slice off this extra "last" element. + if (exclusive_) { + if (reverse_) { + output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, + 1, axis); + + } else { + output = + builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + } + } + ctx->SetOutput(0, output); + } + + private: + const bool sum_; // True=cumulative sum. False=cumulative product. + bool reverse_; + bool exclusive_; +}; + +class CumsumOp : public ScanOp { + public: + explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {} +}; +REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp); + +class CumprodOp : public ScanOp { + public: + explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {} +}; +REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 24a99f253d6dc8bb699fff587c363b12c227e821..e205fadd2b1bcae96a7bfa1bc83096d405ce22c4 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Shape Ops. +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -27,56 +28,42 @@ namespace { class ShapeOp : public XlaOpKernel { public: - explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int rank = input_shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({rank})); - auto vec = shape_constant.vec(); - // TODO(dga): support int64. b/28119922. - for (int i = 0; i < rank; ++i) { - int64 dim_size = input_shape.dim_size(i); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but dim ", i, " is ", dim_size)); - vec(i) = static_cast(dim_size); - } - + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(0, shape_constant); } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("Shape"), ShapeOp); class ShapeNOp : public XlaOpKernel { public: - explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - const TensorShape shape = ctx->InputShape(i); - const int dims = shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({dims})); - auto vec = shape_constant.vec(); - - // TODO(dga): support int64. b/28119922. - for (int j = 0; j < dims; ++j) { - int64 dim_size = shape.dim_size(j); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but shape ", i, " dim ", j, " is ", - dim_size)); - vec(j) = static_cast(dim_size); - } - + const TensorShape input_shape = ctx->InputShape(i); + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(i, shape_constant); } } bool IsExpensive() override { return false; } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..76ea5f525598f511f295eb5a30f3cf603fbf57aa --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" + +#include + +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { + +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant) { + const int dims = input_shape.dims(); + if (shape_constant->dtype() == DT_INT32) { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + if (!FastBoundsCheck(dim_size, std::numeric_limits::max())) { + return errors::InvalidArgument( + "Shape with out_type=int32 does not support tensors > int32max", + " but dim ", i, " is ", dim_size); + } + vec(i) = static_cast(dim_size); + } + } else { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + vec(i) = dim_size; + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h new file mode 100644 index 0000000000000000000000000000000000000000..575086e118080f6799a54d3ae6409b2b641c4341 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Converts a TensorShape to a constant Tensor. +// +// The input TensorShape input_shape is used to populate the elements of +// shape_constant, which is modified in place. +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 89befda346ec06fec23ab1d1c9d910ded8cd806d..806fda632cde64c1b37ae3b9199028d6b6b0a215 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class SpaceToDepthOp : public XlaOpKernel { public: explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,34 +42,100 @@ class SpaceToDepthOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if (data_format_ == FORMAT_NHWC) { + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 1 + i, "]=", input_shape[1 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + reshaped_shape.push_back(input_shape[feature_dim]); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 1); + } + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] / block_size_); + } + output_shape.push_back(input_shape[feature_dim] * block_elems); + } else { + // FORMAT_NCHW + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[2 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 2 + i, "]=", input_shape[2 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + reshaped_shape.push_back(input_shape[feature_dim]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 3); + } + transpose_order.push_back(feature_dim); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] * block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] / block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - const int block_rank = 2; - for (int i = 0; i < block_rank; ++i) { - OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, - errors::InvalidArgument( - "input shape[", 1 + i, "]=", input_shape[1 + i], - " is not divisible by block_size=", block_size_)); - } - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1] / block_size_, block_size_, - input_shape[2] / block_size_, block_size_, input_shape[3]}); + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -69,7 +146,7 @@ class SpaceToDepthOp : public XlaOpKernel { // block_size_, block_size_, // depth] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -79,15 +156,14 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] / block_size_, - input_shape[2] / block_size_, - block_size_ * block_size_ * input_shape[3]}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp); diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..b10880de77e6b9811008076cd4a959c284e558d1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -0,0 +1,279 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/casts.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +// Rotates a 32-bit integer 'v' left by 'distance' bits. +xla::ComputationDataHandle RotateLeftS32(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& v, + int distance) { + return builder->Or( + builder->ShiftLeft(v, builder->ConstantR0(distance)), + builder->ShiftRightLogical(v, builder->ConstantR0(32 - distance))); +} + +// TODO(b/65209188): add a primitive XOR to XLA and call it here, rather than +// building XOR out of other bitwise operators. +xla::ComputationDataHandle BitwiseXor(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& y) { + return builder->Or(builder->And(x, builder->Not(y)), + builder->And(builder->Not(x), y)); +} + +using ThreeFry2x32State = std::array; + +// Implements the ThreeFry counter-based PRNG algorithm. +// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3. +// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf +ThreeFry2x32State ThreeFry2x32(xla::ComputationBuilder* builder, + ThreeFry2x32State input, ThreeFry2x32State key) { + // Rotation distances specified by the Threefry2x32 algorithm. + constexpr std::array rotations = {13, 15, 26, 6, 17, 29, 16, 24}; + ThreeFry2x32State x; + + std::array ks; + // 0x1BD11BDA is a parity constant specified by the ThreeFry2x32 algorithm. + ks[2] = builder->ConstantR0(0x1BD11BDA); + for (int i = 0; i < 2; ++i) { + ks[i] = key[i]; + x[i] = input[i]; + ks[2] = BitwiseXor(builder, ks[2], key[i]); + } + + x[0] = builder->Add(x[0], ks[0]); + x[1] = builder->Add(x[1], ks[1]); + + // Performs a single round of the Threefry2x32 algorithm, with a rotation + // amount 'rotation'. + auto round = [builder](ThreeFry2x32State v, int rotation) { + v[0] = builder->Add(v[0], v[1]); + v[1] = RotateLeftS32(builder, v[1], rotation); + v[1] = BitwiseXor(builder, v[0], v[1]); + return v; + }; + + // There are no known statistical flaws with 13 rounds of Threefry2x32. + // We are conservative and use 20 rounds. + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[1]); + x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(1)); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = builder->Add(x[0], ks[2]); + x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(2)); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[0]); + x[1] = builder->Add(builder->Add(x[1], ks[1]), builder->ConstantR0(3)); + + x = round(x, rotations[4]); + x = round(x, rotations[5]); + x = round(x, rotations[6]); + x = round(x, rotations[7]); + x[0] = builder->Add(x[0], ks[1]); + x[1] = builder->Add(builder->Add(x[1], ks[2]), builder->ConstantR0(4)); + + x = round(x, rotations[0]); + x = round(x, rotations[1]); + x = round(x, rotations[2]); + x = round(x, rotations[3]); + x[0] = builder->Add(x[0], ks[2]); + x[1] = builder->Add(builder->Add(x[1], ks[0]), builder->ConstantR0(5)); + + return x; +} + +// Returns a tensor of 'shape' random values uniformly distributed in the range +// [minval, maxval) +xla::ComputationDataHandle RandomUniform(xla::ComputationBuilder* builder, + const xla::ComputationDataHandle& seed, + const TensorShape& shape, + double minval, double maxval) { + // Split the seed into two 32-bit scalars to form a key. + auto seed0 = builder->Reshape(builder->Slice(seed, {0}, {1}, {1}), {}); + auto seed1 = builder->Reshape(builder->Slice(seed, {1}, {2}, {1}), {}); + ThreeFry2x32State key = {seed0, seed1}; + const int64 size = shape.num_elements(); + + const int64 half_size = MathUtil::CeilOfRatio(size, 2); + const bool size_is_odd = (half_size * 2 != size); + + // Fill the generator inputs with unique counter values. + ThreeFry2x32State inputs; + TF_CHECK_OK(XlaHelpers::Iota(builder, DT_INT32, half_size, &inputs[0])); + inputs[1] = builder->Add(inputs[0], builder->ConstantR0(half_size)); + ThreeFry2x32State outputs = ThreeFry2x32(builder, inputs, key); + + if (size_is_odd) { + outputs[1] = builder->Slice(outputs[1], {0}, {half_size - 1}, {1}); + } + + auto bits = + builder->Reshape(builder->ConcatInDim(outputs, 0), shape.dim_sizes()); + + // Form 22 random mantissa bits, with a leading 1 bit. The leading 1 bit + // forces the random bits into the mantissa. + constexpr int kFloatBits = 32; + constexpr int kMantissaBits = 23; + bits = builder->Or( + builder->ShiftRightLogical( + bits, builder->ConstantR0(kFloatBits - kMantissaBits)), + builder->ConstantR0(bit_cast(1.0f))); + auto floats = builder->BitcastConvertType(bits, xla::F32); + + // We have a floating point number in the range [1.0, 2.0). + // Subtract 1.0f to shift to the range [0.0, 1.0) + floats = builder->Sub(floats, builder->ConstantR0(1.0f)); + // Multiply and add to shift to the range [minval, maxval). + floats = builder->Mul(floats, builder->ConstantR0(maxval - minval)); + floats = builder->Add(floats, builder->ConstantR0(minval)); + return floats; +} + +// Approximation for the inverse error function from +// Giles, M., "Approximating the erfinv function". +// The approximation has the form: +// w = -log((1 - x) * (1 + x)) +// if ( w < 5 ) { +// w = w - 2.5 +// p = sum_{i=1}^n lq[i]*w^i +// } else { +// w = sqrt(w) - 3 +// p = sum_{i=1}^n gq[i]*w^i +// } +// return p*x +xla::ComputationDataHandle ErfInvF32(xla::ComputationBuilder* b, + const xla::ComputationDataHandle& x, + const TensorShape& shape) { + constexpr int kDegree = 9; + constexpr std::array w_less_than_5_constants = { + 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f, + -4.39150654e-06f, 0.00021858087f, -0.00125372503f, + -0.00417768164f, 0.246640727f, 1.50140941f}; + constexpr std::array w_greater_than_5_constants = { + -0.000200214257f, 0.000100950558f, 0.00134934322f, + -0.00367342844f, 0.00573950773f, -0.0076224613f, + 0.00943887047f, 1.00167406f, 2.83297682f}; + + auto one = b->ConstantR0(1.0); + auto w = b->Neg(b->Log(b->Mul(b->Sub(one, x), b->Add(one, x)))); + + auto lt = b->Lt(w, b->ConstantR0(5.0)); + auto coefficient = [&](int i) { + return b->Select( + lt, + b->Broadcast(b->ConstantR0(w_less_than_5_constants[i]), + shape.dim_sizes()), + b->Broadcast(b->ConstantR0(w_greater_than_5_constants[i]), + shape.dim_sizes())); + }; + w = b->Select(lt, b->Sub(w, b->ConstantR0(2.5f)), + b->Sub(b->SqrtF32(w), b->ConstantR0(3.0f))); + auto p = coefficient(0); + for (int i = 1; i < kDegree; ++i) { + p = b->Add(coefficient(i), b->Mul(p, w)); + } + return b->Mul(p, x); +} + +} // namespace + +class StatelessRandomUniformOp : public XlaOpKernel { + public: + explicit StatelessRandomUniformOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* builder = ctx->builder(); + + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape.dims() == 1 && seed_shape.dim_size(0) == 2, + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::ComputationDataHandle seed = ctx->Input(1); + ctx->SetOutput(0, RandomUniform(builder, seed, shape, 0.0, 1.0)); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomUniformOp); +}; + +// TODO(phawkins): generalize to non-float, non-int32 seed types. +REGISTER_XLA_OP(Name("StatelessRandomUniform") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessRandomUniformOp); + +class StatelessRandomNormalOp : public XlaOpKernel { + public: + explicit StatelessRandomNormalOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + TensorShape shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); + + TensorShape seed_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, seed_shape == TensorShape({2}), + errors::InvalidArgument("seed must have shape [2], not ", + seed_shape.DebugString())); + xla::ComputationDataHandle seed = ctx->Input(1); + xla::ComputationBuilder* builder = ctx->builder(); + auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); + // Convert uniform distribution to normal distribution by computing + // sqrt(2) * erfinv(x) + auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), + ErfInvF32(builder, uniform, shape)); + ctx->SetOutput(0, normal); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomNormalOp); +}; + +// TODO(phawkins): generalize to non-float, non-int32 seed types. +REGISTER_XLA_OP(Name("StatelessRandomNormal") + .TypeConstraint("dtype", DT_FLOAT) + .TypeConstraint("Tseed", DT_INT32), + StatelessRandomNormalOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 351fda251798e43b607fb445f2c98abd57b3d86b..03c22354a9425189e6cf7ee5a7201c90ecb1908d 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -311,6 +311,32 @@ class TensorArrayGatherOp : public XlaOpKernel { xla::ComputationDataHandle ta = resource->value; + // Look for the case where the gather takes a simple slice from the + // tensor array (0, 1, 2, 3, 4, ..., N) + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok()) { + bool gather_is_dense_slice = true; + for (auto i = 0; i < const_indices.size(); i++) { + if (const_indices[i] != i) { + gather_is_dense_slice = false; + break; + } + } + + if (gather_is_dense_slice) { + std::vector begin(ta_shape.dims(), 0); + std::vector strides(ta_shape.dims(), 1); + std::vector end(ta_shape.dims(), 1); + end[0] = const_indices.size(); + for (auto i = 1; i < ta_shape.dims(); i++) { + end[i] = ta_shape.dim_size(i); + } + ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + return; + } + } + xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); ctx->SetOutput(0, gather); @@ -352,28 +378,47 @@ class TensorArrayScatterOp : public XlaOpKernel { const xla::ComputationDataHandle value = ctx->Input(2); const xla::ComputationDataHandle flow = ctx->Input(3); - auto slice_dims = value_shape.dim_sizes(); - slice_dims[0] = 1LL; - - std::vector value_starts(value_shape.dims(), 0); - auto value_ends = value_shape.dim_sizes(); - - std::vector value_strides(value_shape.dims(), 1); - - // For every (index, value) pair, update the corresponding TensorArray - // storage. - for (int i = 0; i < num_indices; ++i) { - // Slice out part of the value. - value_starts[0] = i; - value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + // Look for the case where the scatter is for each sub-tensor in order. The + // tensor array implementation allows for this to be a straight addition. + bool scatter_all_elements_in_order = false; + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok() && num_indices == value_shape.dim_size(0)) { + scatter_all_elements_in_order = true; + for (auto i = 0; i < num_indices; i++) { + if (const_indices[i] != i) { + scatter_all_elements_in_order = false; + break; + } + } + } - // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. - auto index = b->Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + if (scatter_all_elements_in_order) { + ta = b->Add(ta, value); + } else { + auto slice_dims = value_shape.dim_sizes(); + slice_dims[0] = 1LL; + + std::vector value_starts(value_shape.dims(), 0); + auto value_ends = value_shape.dim_sizes(); + + std::vector value_strides(value_shape.dims(), 1); + + // For every (index, value) pair, update the corresponding TensorArray + // storage. + for (int i = 0; i < num_indices; ++i) { + // Slice out part of the value. + value_starts[0] = i; + value_ends[0] = i + 1; + auto slice = b->Slice(value, value_starts, value_ends, value_strides); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto start_indices = + b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + } } resource->value = ta; diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index b19ea22f50d2dd44e8d1d81f5930263f364030e1..68847ae7a2cb926edd9d29007e24b0db7fb5a75f 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/no_op.h" namespace tensorflow { @@ -121,5 +123,26 @@ class ResourceGatherOp : public XlaOpKernel { REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), ResourceGatherOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType variable_dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); + Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } + + private: + DataType out_dtype_; +}; + +REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..21ad21f73737a289390ed1ea767db1078d05b466 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -0,0 +1,120 @@ +# Utilities for building XLA computations. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = ["//tensorflow/compiler/tf2xla:friends"], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") + +cc_library( + name = "batch_dot", + srcs = ["batch_dot.cc"], + hdrs = ["batch_dot.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "cholesky", + srcs = ["cholesky.cc"], + hdrs = ["cholesky.h"], + deps = [ + ":batch_dot", + ":triangular_solve", + ":util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "triangular_solve", + srcs = ["triangular_solve.cc"], + hdrs = ["triangular_solve.h"], + deps = [ + ":batch_dot", + ":util", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +xla_test( + name = "triangular_solve_test", + srcs = ["triangular_solve_test.cc"], + deps = [ + ":triangular_solve", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "util", + srcs = ["util.cc"], + hdrs = ["util.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b0e6174475c22e325c090bec5f1d56822e106bc --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -0,0 +1,118 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" + +#include +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +// The current implementation simply unrolls the computation along the batch +// dimension. +xla::StatusOr BatchDot( + xla::ComputationBuilder* builder, xla::ComputationDataHandle x, + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) { + TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, + builder->GetShape(x)); + TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, + builder->GetShape(y)); + + // Check that both tensors have the same number of dimensions. There must be + // at least two (the batch dimensions can be empty). + if (xla::ShapeUtil::Rank(*x_shape) != xla::ShapeUtil::Rank(*y_shape)) { + return errors::InvalidArgument( + "Arguments to BatchedDot have different ranks: ", + xla::ShapeUtil::HumanString(*x_shape), " vs. ", + xla::ShapeUtil::HumanString(*y_shape)); + } + const int ndims = xla::ShapeUtil::Rank(*x_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to BatchedDot must have rank >= 2: ", ndims); + } + + // The batch dimensions must be equal and the matrix dimensions must be + // valid. + std::vector batch_dimension_numbers; + for (int i = 0; i < ndims - 2; ++i) { + if (x_shape->dimensions(i) != y_shape->dimensions(i)) { + return errors::InvalidArgument( + "Dimension ", i, " of inputs to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(*x_shape), " vs ", + xla::ShapeUtil::HumanString(*y_shape)); + } + batch_dimension_numbers.push_back(i); + } + + int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); + int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); + if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { + return errors::InvalidArgument( + "Dimensions ", x_inner_dim, " and ", y_inner_dim, + " of arguments to BatchedDot must be equal: ", + xla::ShapeUtil::HumanString(*x_shape), " transpose: ", transpose_x, + " vs. ", xla::ShapeUtil::HumanString(*y_shape), + " transpose: ", transpose_y); + } + + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::HasZeroElements(*x_shape) || + xla::ShapeUtil::HasZeroElements(*y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape->dimensions(x_outer_dim)); + dimensions.push_back(y_shape->dimensions(y_outer_dim)); + return builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), + dimensions); + } + + if (x_shape->element_type() == xla::C64 && transpose_x) { + x = builder->Conj(x); + } + if (y_shape->element_type() == xla::C64 && transpose_y) { + y = builder->Conj(y); + } + + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; + return builder->Dot(lhs, rhs); + } + + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); + } + return builder->DotGeneral(x, y, dot_dnums); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h new file mode 100644 index 0000000000000000000000000000000000000000..b46bc7417d29dc5b7e9649ac28cc78b57d4b619c --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Multiplies slices of two tensors in batches. + +// Multiplies all slices of `Tensor` `x` and `y` (each slice can be +// viewed as an element of a batch), and arranges the individual results +// in a single output tensor of the same batch size. Each of the +// individual slices can optionally be transposed before multiplication by +// setting the `transpose_x` or `transpose_y` flag to `true`. +// +// The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` +// and `[..., r_y, c_y]`. +// +// The output tensor is 2-D or higher with shape `[..., r_o, c_o]`, where: +// +// r_o = c_x if transpose_x else r_x +// c_o = r_y if transpose_y else c_y +// +// It is computed as: +// +// output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) +// TODO(phawkins): add an option to take the complex conjugate of the LHS or +// RHS. +xla::StatusOr BatchDot( + xla::ComputationBuilder* builder, xla::ComputationDataHandle x, + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BATCH_DOT_H_ diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc new file mode 100644 index 0000000000000000000000000000000000000000..b3cc489adf6042acb3f56b3a0a6c8fbe43bde629 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -0,0 +1,166 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/cholesky.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +namespace { + +// def cholesky_unblocked(a): +// assert len(a.shape) == 2 and a.shape[-2] == a.shape[-1] +// n = a.shape[-2] +// l = np.zeros_like(a) +// for j in xrange(n): +// r = l[..., j, :j] +// l[..., j, j] = np.sqrt(a[..., j, j] - np.dot(r, r)) +// l[..., j+1:, j] = (a[..., j+1:, j] - np.dot(l[..., j+1:, :j], +// np.transpose(r))) / l[..., j, j] +// return l +xla::StatusOr CholeskyUnblocked( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(a)); + xla::ComputationDataHandle l = Zeros(builder, *shape); + const int64 n = xla::ShapeUtil::GetDimension(*shape, -2); + for (int j = 0; j < n; ++j) { + // Picture of block structure: + // ... \ + // \ + // -- r -- d + // |\ + // B c \ + // | \ + // | ... + // + // ^ + // column j + TF_ASSIGN_OR_RETURN(auto d, + SliceInMinorDims(builder, a, {j, j}, {j + 1, j + 1})); + TF_ASSIGN_OR_RETURN(auto c, + SliceInMinorDims(builder, a, {j + 1, j}, {n, j + 1})); + xla::ComputationDataHandle new_d_squared = d; + xla::ComputationDataHandle br; + if (j > 0) { + TF_ASSIGN_OR_RETURN(auto r, + SliceInMinorDims(builder, l, {j, 0}, {j + 1, j})); + TF_ASSIGN_OR_RETURN(auto b, + SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); + TF_ASSIGN_OR_RETURN(auto r_squared, + BatchDot(builder, r, r, /*transpose_x=*/false, + /*transpose_y=*/true)); + new_d_squared = builder->Sub(new_d_squared, r_squared); + + TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, + /*transpose_y=*/true)); + } + auto new_d_inv = builder->Pow( + new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); + auto new_d = builder->Mul(new_d_inv, new_d_squared); + TF_ASSIGN_OR_RETURN(l, UpdateSliceInMinorDims(builder, l, new_d, {j, j})); + + if (j > 0) { + c = builder->Sub(c, br); + } + auto new_c = builder->Mul(c, new_d_inv); + TF_ASSIGN_OR_RETURN(l, + UpdateSliceInMinorDims(builder, l, new_c, {j + 1, j})); + } + return l; +} + +} // namespace + +xla::StatusOr Cholesky( + xla::ComputationBuilder* builder, xla::ComputationDataHandle a, + int64 block_size) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + const int ndims = xla::ShapeUtil::Rank(*a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to Cholesky must have rank >= 2: ", ndims); + } + + const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); + if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + return errors::InvalidArgument( + "Arguments to Cholesky must be square matrices: ", + xla::ShapeUtil::HumanString(*a_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to Cholesky must be >= 1; got ", block_size); + } + + // Blocked left-looking Cholesky factorization. + // Algorithm 1 from + // Haidar, Azzam, et al. "High-performance Cholesky factorization for GPU-only + // execution." Proceedings of General Purpose GPUs. ACM, 2017. + xla::ComputationDataHandle l = Zeros(builder, *a_shape); + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + if (i > 0) { + // TODO(phawkins): consider implementing SYRK for the diagonal part of + // the panel. + // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) + TF_ASSIGN_OR_RETURN(auto lhs, + SliceInMinorDims(builder, l, {i, 0}, {n, i})); + TF_ASSIGN_OR_RETURN(auto rhs, + SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); + TF_ASSIGN_OR_RETURN(auto delta, + BatchDot(builder, lhs, rhs, /*transpose_x=*/false, + /*transpose_y=*/true)); + TF_ASSIGN_OR_RETURN(auto before, + SliceInMinorDims(builder, a, {i, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN( + a, UpdateSliceInMinorDims(builder, a, builder->Sub(before, delta), + {i, i})); + } + + // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) + TF_ASSIGN_OR_RETURN(auto x, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto factorized, CholeskyUnblocked(builder, x)); + TF_ASSIGN_OR_RETURN(l, + UpdateSliceInMinorDims(builder, l, factorized, {i, i})); + + if (i + k < n) { + // l[i+k:, i:i+k] = trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) + TF_ASSIGN_OR_RETURN(auto panel, + SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN(auto update, + TriangularSolve(builder, factorized, panel, + /*block_size=*/8)); + TF_ASSIGN_OR_RETURN( + l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); + } + } + return l; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h new file mode 100644 index 0000000000000000000000000000000000000000..2bead7359baaf3582c1230adf0cd4a90046859d2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Computes the Cholesky decompositions of a batch of symmetric positive +// definite matrices. +// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the +// two minor dimensions equal. +// The algorithm implements a blocked Cholesky decomposition; `block_size` is +// the block size to use. +// TODO(phawkins): check for negative values on the diagonal and return an +// error, instead of silently yielding NaNs. +xla::StatusOr Cholesky( + xla::ComputationBuilder* builder, xla::ComputationDataHandle a, + int64 block_size = 256); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc new file mode 100644 index 0000000000000000000000000000000000000000..579944c3a381e7018b7fee5013d0509158ce21cc --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -0,0 +1,175 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/batch_dot.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +xla::StatusOr TriangularSolve( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + xla::ComputationDataHandle b, int64 block_size) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, + builder->GetShape(b)); + if (xla::ShapeUtil::Rank(*a_shape) != xla::ShapeUtil::Rank(*b_shape)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have different ranks: ", + xla::ShapeUtil::HumanString(*a_shape), " vs. ", + xla::ShapeUtil::HumanString(*b_shape)); + } + const int ndims = xla::ShapeUtil::Rank(*a_shape); + if (ndims < 2) { + return errors::InvalidArgument( + "Arguments to TriangularSolve must have rank >= 2: ", ndims); + } + // The batch dimensions must be equal. + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape->dimensions(i); + int64 b_size = b_shape->dimensions(i); + if (a_size != b_size) { + return errors::InvalidArgument( + "Batch dimensions of arguments to TriangularSolve must be equal: ", + xla::ShapeUtil::HumanString(*a_shape), " vs ", + xla::ShapeUtil::HumanString(*b_shape)); + } + batch_dimensions.push_back(a_size); + } + + const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + return errors::InvalidArgument( + "The 'a' arguments to TriangularSolve must be square matrices: ", + xla::ShapeUtil::HumanString(*a_shape)); + } + if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) { + return errors::InvalidArgument( + "Arguments to TriangularSolve have incompatible matrix shapes: ", + xla::ShapeUtil::HumanString(*a_shape), " vs ", + xla::ShapeUtil::HumanString(*b_shape)); + } + + if (block_size < 1) { + return errors::InvalidArgument( + "block_size argument to TriangularSolve must be >= 1; got ", + block_size); + } + + // Returns [b1, b2, ... , bn, indices[0], indices[1]]. + auto prepend_batch_dims = [&](std::array indices) { + std::vector output(ndims); + std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); + std::copy(indices.begin(), indices.end(), + output.begin() + batch_dimensions.size()); + return output; + }; + + std::map base_computations; + auto get_base_triangular_solve = + [&](int k) -> xla::StatusOr { + xla::Computation& computation = base_computations[k]; + if (computation.IsNull()) { + std::unique_ptr sub = builder->CreateSubBuilder( + tensorflow::strings::StrCat("trsm_base_", k)); + + auto a_param = + sub->Parameter(0, + xla::ShapeUtil::MakeShape(b_shape->element_type(), + prepend_batch_dims({k, k})), + "a"); + + auto b_param = + sub->Parameter(1, + xla::ShapeUtil::MakeShape(b_shape->element_type(), + prepend_batch_dims({m, k})), + "b"); + + // TODO(phawkins): it might make sense to use a while loop here, rather + // than unrolling. + // TODO(phawkins): the left-looking variant of the algorithm might be more + // efficient at block size 1. + TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, + /*block_size=*/1) + .status()); + + TF_ASSIGN_OR_RETURN(computation, sub->Build()); + } + return &computation; + }; + + xla::ComputationDataHandle output = Zeros(builder, *b_shape); + + // Right-looking blocked triangular solve. + // For an explanation of the algorithm, see the TRSM discussion in: + // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation + // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 + // (2008): 4. + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // if k > 1: + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right', + // kind='Lower', transpose=True, block_size=1) + // else: + // output[..., :, i] = b[..., :, i] / a[..., i, i] + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, a_slice); + } + + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k], + // np.transpose(..., a[i+k:, i:i+k])) + if (i + k < n) { + TF_ASSIGN_OR_RETURN(auto a_slice_2, + SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/true)); + + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, i + k}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + } + } + return output; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h new file mode 100644 index 0000000000000000000000000000000000000000..501d026411c80359c7efa406ece5929a2e46ac1f --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" + +namespace tensorflow { + +// Solves systems of linear equations with upper or lower triangular matrices by +// backsubstitution. +// +// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form +// square matrices. The strictly upper triangular part of each inner-most matrix +// is assumed to be zero and not accessed. +// `b` is a tensor of shape `[..., M, K]`. +// +// The innermost matrices in the output satisfy matrix equations +// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`. +// +// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no +// blocking is used. +// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right, +// kind=lower, and transposed_a=true. Implement the other possible combinations +// of side, kind and transposed_a. +xla::StatusOr TriangularSolve( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + xla::ComputationDataHandle b, int64 block_size = 256); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_TRIANGULAR_SOLVE_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..671d9aa4fe0c042a3cc44468074653d51c2be75d --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace tensorflow { +namespace { + +using TriangularSolveTest = xla::ClientLibraryTestBase; + +XLA_TEST_F(TriangularSolveTest, Simple) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::Array2D a_vals({ + {2, 0, 0, 0}, + {3, 6, 0, 0}, + {4, 7, 9, 0}, + {5, 8, 10, 11}, + }); + xla::Array2D b_vals({ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + }); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(b_vals, 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 0.08333334, 0.04629629, 0.03367003}, + {2.5, -0.25, -0.1388889, -0.1010101}, + {4.5, -0.58333331, -0.32407406, -0.23569024}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(2e-3, 2e-3)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc new file mode 100644 index 0000000000000000000000000000000000000000..943248aedbdce5e81baa341fdab82fea9a48302d --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, + xla::Shape& shape) { + return builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), + xla::AsInt64Slice(shape.dimensions())); +} + +xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, double value) { + switch (type) { + case xla::F16: + return builder->ConstantR0(static_cast(value)); + break; + case xla::BF16: + return builder->ConstantR0(static_cast(value)); + break; + case xla::F32: + return builder->ConstantR0(static_cast(value)); + break; + case xla::F64: + return builder->ConstantR0(value); + break; + case xla::C64: + return builder->ConstantR0(value); + break; + default: + LOG(FATAL) << "unhandled element type " << type; + } +} + +xla::StatusOr SliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + gtl::ArraySlice start, gtl::ArraySlice end) { + TF_RET_CHECK(start.size() == end.size()); + int64 n_minor_dims = start.size(); + + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_RET_CHECK(n_minor_dims <= n_dims); + gtl::ArraySlice major_dims(xla::AsInt64Slice(shape->dimensions()), + /*pos=*/0, + /*len=*/n_dims - n_minor_dims); + + // Prepends 0s in the major dim + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + major_dims.size()); + + // Prepends the shape of the major dims. + std::vector padded_end(n_dims); + std::copy(major_dims.begin(), major_dims.end(), padded_end.begin()); + std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size()); + + std::vector strides(n_dims, 1); + return builder->Slice(x, padded_start, padded_end, strides); +} + +xla::StatusOr UpdateSlice( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start) { + // TODO(phawkins): make int64 work on all backends, remove the int32 cast. + std::vector start_as_int32(start.begin(), start.end()); + return builder->DynamicUpdateSlice( + x, update, builder->ConstantR1(start_as_int32)); +} + +xla::StatusOr UpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + const int64 n_minor_dims = start.size(); + TF_RET_CHECK(n_minor_dims <= n_dims); + std::vector padded_start(n_dims, 0); + std::copy(start.begin(), start.end(), + padded_start.begin() + (n_dims - n_minor_dims)); + return UpdateSlice(builder, x, update, padded_start); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h new file mode 100644 index 0000000000000000000000000000000000000000..8fba6b5cf247e9b2c26533c53ece8b0d7d4f4c36 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Returns a zero-filled tensor with shape `shape`. +xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, + xla::Shape& shape); + +// Returns a floating point scalar constant of 'type' with 'value'. +// If 'type' is complex, returns a real value with zero imaginary component. +xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, double value); + +// Performs a slice in the minor dimensions of a Tensor. +xla::StatusOr SliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + gtl::ArraySlice start, gtl::ArraySlice end); + +// Updates a slice of 'x', i.e., +// x[start[0], ..., start[n]] = update +xla::StatusOr UpdateSlice( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start); + +// Updates a slice of 'x', where 'start' contains a list of minor dimensions: +// x[..., start[0], ..., start[n]] = update +xla::StatusOr UpdateSliceInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, + const xla::ComputationDataHandle& update, gtl::ArraySlice start); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..b08a7583cb5ab7efa30a1fa27b973d04992584a7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/tf2xla/sharding_util.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace { +const char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE"; +const char kShardingAttribute[] = "_XlaSharding"; +} // namespace + +namespace { +xla::StatusOr> +GetShardingFromNodeDef(const NodeDef& node_def) { + if (!HasNodeAttr(node_def, kShardingAttribute)) { + return tensorflow::gtl::optional(); + } + string value; + xla::OpSharding sharding; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); + if (!sharding.ParseFromString(value)) { + return xla::InvalidArgument( + "Experimental _XlaSharding attribute was not a valid encoded " + "xla::OpSharding proto."); + } + return tensorflow::gtl::optional(sharding); +} + +Status CoreOutOfRangeError(int core, int num_cores_per_replica) { + return errors::InvalidArgument( + "Invalid replicated core id: ", core, + "; num_cores_per_replica=", num_cores_per_replica); +} +} // namespace + +xla::StatusOr> +ParseShardingFromDevice( + const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional explicit_sharding) { + if (device_name.empty()) { + return tensorflow::gtl::optional(); + } + DeviceNameUtils::ParsedName parsed_device; + if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { + return errors::InvalidArgument("Malformed assigned device '", device_name, + "'"); + } + + if (explicit_sharding.has_value()) { + return explicit_sharding; + } else if (!parsed_device.has_type || !parsed_device.has_id || + !StringPiece(parsed_device.type) + .contains(kDeviceSuffixReplicatedCore)) { + return tensorflow::gtl::optional(); + } else { + const int core = parsed_device.id; + if (core < 0 || core >= num_cores_per_replica) { + return CoreOutOfRangeError(core, num_cores_per_replica); + } + return tensorflow::gtl::optional( + xla::ShardingBuilder::AssignDevice(core)); + } +} + +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) { + const string& device_name = node_def.device(); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node_def)); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); +} + +xla::StatusOr> +ParseShardingFromDevice(const Node& node, int num_cores_per_replica) { + string device_name = node.assigned_device_name(); + if (device_name.empty()) { + device_name = node.requested_device(); + } + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node.def())); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); +} + +void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { + string device_name = src.assigned_device_name(); + if (device_name.empty()) { + device_name = src.requested_device(); + } + dst->set_assigned_device_name(device_name); + if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) { + dst->AddAttr(kShardingAttribute, *attr); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h new file mode 100644 index 0000000000000000000000000000000000000000..9e430e30a1247c7d01910b6d57f7c577964e1dd1 --- /dev/null +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Parses the op sharding from the 'replicated core' device_name . +// Returns an error: +// - if the device name is invalid. +// - the core is parsed and is out of the range [0, num_cores_per_replica). +// +// Otherwise, returns either: +// - explicit_sharding if explicit_sharding.has_value() +// - a non-value if there is no assigned core or +// - a sharding set as per xla::ShardingBuilder::AssignDevice. +xla::StatusOr> +ParseShardingFromDevice(const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional + explicit_sharding = tensorflow::gtl::nullopt); + +xla::StatusOr> +ParseShardingFromDevice(const Node& node, int num_cores_per_replica); + +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica); + +void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bff5978237a827cb9650541f2cf6984d9e846796 --- /dev/null +++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/tf2xla/sharding_util.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(CoreUtilTest, ParseShardingFromDevice) { + Graph graph(OpRegistry::Global()); + + auto core_from_sharding = + [](tensorflow::gtl::optional sharding) -> int64 { + if (sharding.has_value() && + sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL) { + return sharding.value().tile_assignment_devices(0); + } else { + return -1; + } + }; + + auto parse_status = ParseShardingFromDevice("", 1); + TF_EXPECT_OK(parse_status.status()); + EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie())); + parse_status = ParseShardingFromDevice("", 100); + TF_EXPECT_OK(parse_status.status()); + EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie())); + + parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:-1", 100); + EXPECT_FALSE(parse_status.ok()); + + parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:55", 100); + TF_EXPECT_OK(parse_status.status()); + EXPECT_EQ(55, core_from_sharding(parse_status.ValueOrDie())); + + parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:100", 100); + EXPECT_FALSE(parse_status.ok()); + + parse_status = ParseShardingFromDevice("/cpu:0", 100); + TF_EXPECT_OK(parse_status.status()); + EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie())); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index a14c93a2b9494b89f579bc20ee0510c136f8f01b..906f2290433face4cce3296b2f815d50d8c496ce 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -253,8 +253,7 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -277,7 +276,6 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), "tfcompile", std::move(graph), xla_args, &result)); - *requires_runtime_context = result.requires_runtime_context; *computation = std::move(*result.computation); int num_const_results = 0; @@ -352,12 +350,10 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); - TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation, - requires_runtime_context)); + TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index ab99beebf7946237425d4d304a858ac6817177b8..473c431b12d441c652f1d0d6c11c5e87836ab36d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -30,13 +30,9 @@ namespace tensorflow { // // The computation is built in the context of the given `client`, which may // subsequently be used to compile or execute the computation. -// -// If `requires_runtime_context` is filled with true, this indicates the last -// argument of the computation is XlaLocalRuntimeContext*. Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context); + xla::Computation* computation); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..7aca889a266439538c4cd1c153460e6cc871b246 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace tf2xla { +namespace { + +void PrintSupportedOps(const string& device, const string& regen_run) { + XlaOpRegistry::RegisterCompilationKernels(); + + std::vector kdefs = + XlaOpRegistry::DeviceKernels(device, + /*include_compilation_only_kernels=*/true); + std::sort( + kdefs.begin(), kdefs.end(), + [](const KernelDef* a, const KernelDef* b) { return a->op() < b->op(); }); + + std::cout << "**Supported operators for device: " << device << "**\n\n" + << "Operator | Type Constraint\n" + << "-------- | ---------------" << std::endl; + for (const KernelDef* kdef : kdefs) { + std::vector constraints; + for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) { + std::vector types; + for (int type : constraint.allowed_values().list().type()) { + types.push_back(DataTypeString(static_cast(type))); + } + std::sort(types.begin(), types.end()); + constraints.push_back("`" + constraint.name() + "={" + + str_util::Join(types, ",") + "}`"); + } + std::cout << "`" << kdef->op() << "` | " + << str_util::Join(constraints, "
") << std::endl; + } + + std::cout << "\nTo regenerate this table, run:\n\n```shell\n" + << regen_run << " --device=" << device << "\n```" << std::endl; +} + +} // namespace + +void SupportedOpsMain(int argc, char** argv, const char* regen_run) { + std::vector device_names = XlaOpRegistry::BackendNames(); + std::sort(device_names.begin(), device_names.end()); + + // Set up and parse flags. + string device; + std::vector flag_list = { + {"device", &device, + "Name of the compilation device for which to print supported ops, " + "one of: " + + str_util::Join(device_names, ",")}, + }; + string usage = Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + QCHECK(XlaOpRegistry::IsBackendRegistered(device)) + << "\nUnknown device: " << device << "\n" + << usage; + + // Run the program. + port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " + "other than flags\n\n" + << usage; + PrintSupportedOps(device, regen_run); +} + +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..1b45fb4cdd3b0173b04e130b7416874a9a406dc5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ + +namespace tensorflow { +namespace tf2xla { + +// The implementation of a main function for a binary that prints a table of +// supported tf2xla operators for a given device, along with their type +// constraints, to stdout. +// +// Pass the argc and argv from main, unmodified. Use regen_run to specify the +// command used to regenerate the table. +void SupportedOpsMain(int argc, char** argv, const char* regen_run); + +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..690666c2400d45e33c1a5d1818b68a86a70a5be3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc @@ -0,0 +1,22 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +int main(int argc, char** argv) { + const char* regen_run = + "bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops"; + tensorflow::tf2xla::SupportedOpsMain(argc, argv, regen_run); +} diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 51ce17deb62117ff8c1075160d0bebe6cf1438f1..a9978e697b091715ce120f0d18fdddd259e08b32 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -70,10 +70,7 @@ TEST(ConvertGraphDefToXla, Sum) { xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); xla::Computation computation; - bool requires_runtime_context; - TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation, - &requires_runtime_context)); - ASSERT_FALSE(requires_runtime_context); + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. auto x_literal = xla::Literal::CreateR0(10); @@ -92,7 +89,7 @@ TEST(ConvertGraphDefToXla, Sum) { client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); std::unique_ptr result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42,\n)", result->ToString()); + EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); } } // namespace diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 14e0910cab2c3aa329fe798d199454fd6c5ee6a5..55f2f3149c6ba7bfa18608f961c8a76103a50756 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -29,6 +31,7 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -250,4 +253,32 @@ string TensorIdToString(const tf2xla::TensorId& id) { return strings::StrCat(id.node_name(), ":", id.output_index()); } +Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { + int core = -1; + const Node* matching_node = nullptr; + for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) { + if (edge->IsControlEdge()) continue; + const Node* possible_match = out_edges ? edge->dst() : edge->src(); + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::optional sharding, + ParseShardingFromDevice( + *possible_match, + /*num_cores_per_replica=*/std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + const int core_annotation = sharding.value().tile_assignment_devices(0); + if (core == -1 || core > core_annotation) { + core = core_annotation; + matching_node = possible_match; + } + } + } + if (matching_node != nullptr) { + n->set_assigned_device_name(matching_node->assigned_device_name()); + n->set_requested_device(matching_node->requested_device()); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index a29d0c16f9cfde3c97bfa9cf3165890f83939a43..e5fba8ede7745febbb42c572a7b52247213afc95 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -45,6 +46,11 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Returns node:port for the given . string TensorIdToString(const tf2xla::TensorId& id); +// Updates the sharding of based on the sharding of its neighbors. +// If is true, outgoing edges from are considered; else incoming +// edges are considered. +Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index b98c89f284d6a2bfc6d043794a580e60da93617f..436039e154842443f779aba276bc571fc2ab7537 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -15,7 +15,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/data_flow_ops.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -211,5 +217,52 @@ TEST(PruneGraphDefInto, Basic) { EXPECT_EQ(def.DebugString(), copy.DebugString()); } +TEST(SetNodeShardingFromNeighbors, Basic) { + // Builds a graph that adds two Tensors. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + auto c = ops::Add(scope.WithOpName("C"), a, b); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + Node* a_node = nullptr; + Node* b_node = nullptr; + Node* c_node = nullptr; + for (Node* n : graph->nodes()) { + if (n->name() == "A") a_node = n; + if (n->name() == "B") b_node = n; + if (n->name() == "C") c_node = n; + } + + const int num_cores_per_replica = 4; + + a_node->set_assigned_device_name("foo"); + EXPECT_FALSE(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false).ok()); + + // Test where one input to c_node has a device. + a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2"); + TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false)); + auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica); + TF_ASSERT_OK(parse_status.status()); + ASSERT_TRUE(parse_status.ValueOrDie().has_value()); + EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0)); + + // Test where two inputs to c_node have a device. + b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1"); + TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false)); + parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica); + TF_ASSERT_OK(parse_status.status()); + ASSERT_TRUE(parse_status.ValueOrDie().has_value()); + EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); + + // Test setting based on out edges. + TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true)); + parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica); + TF_ASSERT_OK(parse_status.status()); + ASSERT_TRUE(parse_status.ValueOrDie().has_value()); + EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index 1efbe0ffb17dad5332aa700b2e255d4a99fbef72..c969212a1bfaa6cab0d896ee074cfd4e2b283ae4 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -49,6 +49,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_UINT64: *type = xla::U64; return Status::OK(); + case tensorflow::DT_BFLOAT16: + *type = xla::BF16; + return Status::OK(); case tensorflow::DT_HALF: *type = xla::F16; return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index fc866a4c0a34712dc3906fb60c13a30909ecffd2..cc459dc87c00f19230c65341d53da213e07fe364 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -77,7 +78,8 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, : LocalDevice( options, Device::BuildDeviceAttributes( - "", type, Bytes(256 << 20), DeviceLocality(), + strings::StrCat("/device:", type.type(), ":0"), type, + Bytes(256 << 20), DeviceLocality(), strings::StrCat("device: XLA compilation device ", type.type()))), allocator_(new XlaCompilationAllocator()) {} @@ -97,23 +99,19 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, metadata.set_op_name(op_kernel->name()); b->SetOpMetadata(metadata); - DeviceNameUtils::ParsedName parsed; - OP_REQUIRES( - context, - DeviceNameUtils::ParseFullName(op_kernel->requested_device(), &parsed), - errors::Internal("Unable to parse device name: ", - op_kernel->requested_device())); - // If no device ID assignment is found, XLA is free to use whatever device it - // wants. In practice this usually has the effect of placing things on - // device 0. - if (parsed.has_id) { - b->SetSharding(xla::ShardingBuilder::AssignDevice(parsed.id)); - } + auto sharding_parse_result = ParseShardingFromDevice( + op_kernel->def(), std::numeric_limits::max()); + OP_REQUIRES_OK(context, sharding_parse_result.status()); + tensorflow::gtl::optional op_sharding = + sharding_parse_result.ValueOrDie(); + // If no sharding metadata is found, XLA is free to use whatever device it + // wants. In practice this usually has the effect of placing things on device + // 0. + xla::ScopedShardingAssignment assign_sharding(b, op_sharding); op_kernel->Compute(context); b->ClearOpMetadata(); - b->ClearSharding(); VLOG(4) << "Done"; } diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index b5c17c5273bb15e20184b2fefd93880d4828105e..79da701fd244a461a60588153b601d5c1870fa89 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -28,9 +28,10 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, temps_(new void*[static_data.num_temps]), arg_names_(static_data.arg_names), result_names_(static_data.result_names), - program_shape_(static_data.program_shape) { + program_shape_(static_data.program_shape), + hlo_profile_printer_(static_data.hlo_profile_printer) { // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( static_data.arg_sizes, static_data.num_args, args_, /*annotate_initialized=*/false); @@ -39,9 +40,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); - // The runtime context is always the last arg, if it is required. - if (static_data.requires_runtime_context) { - args_[static_data.num_args - 1] = &context_; + // If Hlo profiling is enabled the generated code expects an appropriately + // sized buffer to be passed in as the last argument. If Hlo profiling is + // disabled the last function argument is still present in the function + // signature, but it is ignored by the generated code and we pass in null for + // it. + if (hlo_profiling_enabled()) { + profile_counters_ = new int64[static_data.profile_counters_size](); } } @@ -50,6 +55,7 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] profile_counters_; } namespace { diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index f49a7889222ff989144217ab10b27595f89e4311..e0ae3ed9a811bcc49ce8862037a67d293e879e57 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -16,10 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ -#include +#include #include -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -27,6 +26,7 @@ limitations under the License. // never use this functionality. namespace xla { class ProgramShape; +class HloProfilePrinter; } namespace tensorflow { @@ -48,12 +48,10 @@ namespace tensorflow { class XlaCompiledCpuFunction { public: // Type of the raw function, produced by either JIT or AOT. - // - // TODO(toddw): Add support for hlo profiling, and replace std::function with - // a raw function pointer, for some codesize savings. - using RawFunction = std::function; + using RawFunction = void (*)(void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps, + int64* profile_counters); // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for @@ -71,9 +69,6 @@ class XlaCompiledCpuFunction { // The 0-based index of the result tuple, in the temp buffers. size_t result_index = 0; - // Is the final arg XlaLocalRuntimeContext? - bool requires_runtime_context = false; - // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. const char** arg_names = nullptr; @@ -81,21 +76,29 @@ class XlaCompiledCpuFunction { // [Optional] Arg and result shapes. const xla::ProgramShape* program_shape = nullptr; + + // [Optional] Profile printer. Null if profiling is disabled. + const xla::HloProfilePrinter* hlo_profile_printer = nullptr; + + // [Optional] The number of profile counters expected in the profile counter + // buffer by the generated code and hlo_profile_printer. 0 if profiling is + // disabled. + int64 profile_counters_size = 0; }; // AllocMode controls the buffer allocation mode. enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, + // Allocate all buffers - args, results, profile and temps. + ARGS_RESULTS_PROFILES_AND_TEMPS, - // Only allocate result and temp buffers. + // Only allocate result, profile and temp buffers. // Use set_arg_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, + RESULTS_PROFILES_AND_TEMPS_ONLY, }; XlaCompiledCpuFunction( const StaticData& static_data, - AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; @@ -104,21 +107,22 @@ class XlaCompiledCpuFunction { // Sets the intra-op thread pool used to run individual ops concurrently. void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { run_options_.set_intra_op_thread_pool(pool); - context_.thread_pool = pool; } // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. bool Run() { - context_.error = false; - context_.error_msg.clear(); raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_); - return !context_.error; + const_cast(args_), temps_, profile_counters_); + return true; } // Returns the error message from the previous failed Run call. - const string& error_msg() const { return context_.error_msg; } + // + // TODO(fschneider): For now this always returns an empty string because there + // is no support for error reporting in XLA. Remove this once all callers are + // updated. + string error_msg() const { return {}; } // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. @@ -141,10 +145,6 @@ class XlaCompiledCpuFunction { // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in // tensorflow/compiler/aot/runtime.h to ensure correct alignment. // - // If StaticData.requires_runtime_context==true, the final argument is an - // XlaLocalRuntimeContext, which is managed internally by this class, and - // should not be changed. - // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. void set_arg_data(size_t index, void* data) { args_[index] = data; } @@ -162,6 +162,16 @@ class XlaCompiledCpuFunction { return static_cast(temps_[result_index_]); } + // Profile counters for this XLA computation. + // + // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in + // this case) these counters are non-null and are automatically populated by + // `Run`. The counters can then be pretty-printed using + // `hlo_profile_printer()`. + // + // When Hlo profiling is disabled, this accessor returns null. + const int64* profile_counters() const { return profile_counters_; } + // Returns the buffer for the positional result at the given `index`. void* result_data(size_t index) { return results()[index]; } const void* result_data(size_t index) const { return results()[index]; } @@ -195,6 +205,12 @@ class XlaCompiledCpuFunction { // program shape isn't available. const xla::ProgramShape* ProgramShape() const { return program_shape_; } + bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } + const xla::HloProfilePrinter& hlo_profile_printer() const { + assert(hlo_profiling_enabled()); + return *hlo_profile_printer_; + } + private: const RawFunction raw_function_; const size_t result_index_; @@ -208,14 +224,17 @@ class XlaCompiledCpuFunction { void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; + // Backing memory for profiling counters. + int64* profile_counters_ = nullptr; + // Options and context passed to the compiled function. xla::ExecutableRunOptions run_options_; - tensorflow::XlaLocalRuntimeContext context_; // Optional metadata. const char** arg_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShape* program_shape_ = nullptr; + const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index e49663b8b047fb5f2c9ba17fa0aa032a673e7ed7..4c01e6732128fbb62fb134ad7fa3233725f53ebb 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -23,6 +23,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -160,10 +162,10 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { return graph; } -Status XlaCompiler::CompileFunction( - const XlaCompiler::CompileOptions& options, const NameAttrList& function, - const std::vector& args, - XlaCompiler::CompilationResult* result) { +Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + std::vector args, + XlaCompiler::CompilationResult* result) { const string function_id = Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; @@ -184,6 +186,25 @@ Status XlaCompiler::CompileFunction( std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); + // _Arg and _Retval nodes don't exist in the stored subgraph for the function; + // they are added by the function body looked up. Therefore, they don't have + // core assignments here. + // Attempt to assign a core to each _Retval and _Arg. Chooses the + // lowest-numbered core that consumes the argument. We choose the + // lowest-numbered core so the assignment is deterministic. + for (Node* n : graph->nodes()) { + if (StringPiece(n->type_string()) == "_Arg") { + TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); + } + } + // Do _Retval as a second loop, in case the retval's input is an _Arg (which + // may have gotten a device assignment from the first loop). + for (Node* n : graph->nodes()) { + if (StringPiece(n->type_string()) == "_Retval") { + TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); + } + } + if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileFunction: " << dump_graph::DumpGraphToFile( @@ -241,13 +262,15 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. -Status BuildArguments(const std::vector& args, +Status BuildArguments(const Graph& graph, + const std::vector& args, bool use_tuple_arg, xla::ComputationBuilder* builder, - XlaContext* context, + XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes) { arg_expressions->resize(args.size()); + *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. @@ -302,6 +325,26 @@ Status BuildArguments(const std::vector& args, (*input_mapping)[i] = parameters[i]; } + // Use the _Arg nodes in the graph to resolve core assignments. + for (const Node* n : graph.nodes()) { + if (StringPiece(n->type_string()) != "_Arg") continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0 && index < args.size()) + << "_Arg out of bounds: " << index << " vs " << args.size(); + TF_ASSIGN_OR_RETURN( + auto sharding, + ParseShardingFromDevice(*n, std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + const int core = sharding.value().tile_assignment_devices(0); + if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) { + (*arg_cores)[index] = core; + } + } + } + // Build parameter handles for non-constant arguments. std::vector arg_handles(parameters.size()); if (use_tuple_arg) { @@ -309,10 +352,18 @@ Status BuildArguments(const std::vector& args, xla::ComputationDataHandle tuple = builder->Parameter(0, tuple_shape, "arg_tuple"); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { + const int core = (*arg_cores)[parameters[i]]; + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional() + : xla::ShardingBuilder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); } } else { for (std::vector::size_type i = 0; i < parameters.size(); ++i) { + const int core = (*arg_cores)[parameters[i]]; + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional() + : xla::ShardingBuilder::AssignDevice(core)); arg_handles[i] = builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); } @@ -368,6 +419,7 @@ Status BuildArguments(const std::vector& args, // type of the final output. Status BuildComputation( const std::vector& args, + const std::vector& arg_cores, const std::vector& retvals, const std::vector>& resources, bool return_updated_values_for_all_resources, @@ -398,6 +450,8 @@ Status BuildComputation( for (const XlaResource* resource : arg_resources) { const XlaCompiler::Argument& arg = args[resource->arg_num]; + const int core = arg_cores[resource->arg_num]; + DCHECK_LT(resource->arg_num, arg_cores.size()); bool modified = resource->value.handle() != resource->initial_value.handle(); // TensorArray gradients were modified if their values changed or there are @@ -417,8 +471,21 @@ Status BuildComputation( for (const auto& grad : resource->tensor_array_gradients) { update.tensor_array_gradients_accessed.insert(grad.first); } + + // Request that the value be returned on a specific core. + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional() + : xla::ShardingBuilder::AssignDevice(core)); + xla::ComputationDataHandle handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Since we can't change the sharding metadata of as this point, + // create a tuple/get-tuple-element combination so that sharding + // assignment will be placed on this value, which will cause the resource + // update to be returned from the same device that provided the resource. + handle = builder->GetTupleElement(builder->Tuple({handle}), 0); + elems.push_back(handle); } } @@ -476,12 +543,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options.resolve_compile_time_constants); core::ScopedUnref context_unref(context); - result->tuple_arg = options.use_tuple_arg; - std::vector arg_expressions; + std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( - args, options.use_tuple_arg, &builder, context, &arg_expressions, - &result->input_mapping, &result->xla_input_shapes)); + *graph, args, options.use_tuple_arg, &builder, context, &arg_cores, + &arg_expressions, &result->input_mapping, &result->xla_input_shapes)); context->set_args(std::move(arg_expressions)); TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, @@ -491,16 +557,11 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_computation_outputs; result->computation = std::make_shared(); TF_RETURN_IF_ERROR(BuildComputation( - args, context->retvals(), context->resources(), + args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->resource_updates)); - result->requires_runtime_context = context->has_context_parameter(); - - // Tuple arguments and runtime context parameters are incompatible. - TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); - VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; result->outputs.resize(context->retvals().size()); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index a8882a638caf2d742bfa2b4f68140e1dc4520db1..380e24e96bc713af4453f92a5359995e9ab4734a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -54,8 +54,6 @@ namespace tensorflow { // +---------------------+-----------------------------------------+ // Within each block, the arguments are arranged by the _Arg index from which // they were derived. -// If `Options::requires_runtime_context` is true, then an additional runtime -// context argument is passed as a final argument. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -191,16 +189,9 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Does the computation require the local runtime context to be passed as - // the last argument? - bool requires_runtime_context = false; - // Input shapes of the computation. std::vector xla_input_shapes; - // Should the arguments be packed into a single tuple? - bool tuple_arg; - // Output shape in XLA format. The output shape is always a tuple. xla::Shape xla_output_shape; @@ -232,16 +223,9 @@ class XlaCompiler { int graph_def_version = TF_GRAPH_DEF_VERSION; // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall() - // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed - // to the computation. + // for CPU. bool allow_cpu_custom_calls = false; - // If 'local_executable_has_hybrid_result', the top-level pointers of the - // result tuple of compiled programs are stored in host memory and the - // nested buffers in device memory, otherwise the whole result tuple is - // stored in device memory. - bool local_executable_has_hybrid_result = false; - // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects @@ -255,8 +239,7 @@ class XlaCompiler { Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, - const std::vector& args, - CompilationResult* result); + std::vector args, CompilationResult* result); // Compiles a tensorflow::Graph into an xla::Computation. // Similar to CompileFunction, but takes a Graph as input rather than a diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 651bafd6c5d946adfedd63ebbe93e4ea016f0b37..5d19dd353fc04744e196bb50c35cb60b35d8b258 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -70,24 +70,6 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants) {} -const xla::ComputationDataHandle& -XlaContext::GetOrCreateRuntimeContextParameter() { - CHECK(allow_cpu_custom_calls_); - if (has_context_parameter_) return context_parameter_; - has_context_parameter_ = true; - - // Allocate the next available parameter for the context parameter. - int num_parameters = 0; - for (const XlaExpression& arg : args_) { - if (!arg.has_constant_value()) { - ++num_parameters; - } - } - context_parameter_ = builder_->Parameter( - num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context"); - return context_parameter_; -} - string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value @@ -178,6 +160,20 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } +const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { + return LookupOrCreate(type, &mul_func_, [this, type] { + const string type_string = DataTypeString(type); + VLOG(1) << "Building Mul() for " << type_string; + xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + b.Mul(x, y); + return b.Build().ConsumeValueOrDie(); + }); +} + const xla::Computation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, const std::function& create) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index de8aafa3628e6eebdabbc508cd95a2ac86e3472f..ebd758d1540eba5483714265565ad22c244ca4a3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -56,15 +56,10 @@ class XlaContext : public ResourceBase { xla::ComputationBuilder* builder(); bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool has_context_parameter() const { return has_context_parameter_; } const std::vector& args() const { return args_; } void set_args(std::vector args); - // Get the runtime context parameter, adding one if it does not already exist. - // Dies if not compiling a local executable. - const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter(); - const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value @@ -102,6 +97,11 @@ class XlaContext : public ResourceBase { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Get an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -119,13 +119,6 @@ class XlaContext : public ResourceBase { // run-time computation outptus. const bool resolve_compile_time_constants_; - // When 'has_context_parameter_' is true, this is the computation handle - // for an additional final parameter to the computation, through which will be - // passed a XlaLocalRuntimeContext* at runtime. Created on demand by - // GetOrCreateRuntimeContextParameter(). - bool has_context_parameter_ = false; - xla::ComputationDataHandle context_parameter_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; @@ -155,6 +148,9 @@ class XlaContext : public ResourceBase { // Cached computation to compute Sum of two elements, specialized by type. ComputationMap add_func_; + // Cached computation to compute Mul of two elements, specialized by type. + ComputationMap mul_func_; + // Cached computation to compute Sigmoid of an element, specialized by type. ComputationMap sigmoid_func_; diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index d504613d232c779e47a506657d2825d052e726dc..8ca757e72355d890c13b8b448d35c327d3986696 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -21,8 +21,6 @@ namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to // slow code. - // TODO(b/34969189) The implementation of TruncatedNormal generates illegal - // code on GPU. if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { return false; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index de5ad5f176536e1453da518b96ee755c7f1e8fdc..ec9e535b707beec6ea26dc81c7ee76b1d4da9225 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines helper routines for Tla JIT compilation. +// This file defines helper routines for XLA compilation. #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" + #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" @@ -26,6 +29,67 @@ limitations under the License. namespace tensorflow { +namespace { + +Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, bool is_min, + xla::ComputationDataHandle* argminmax) { + xla::ComputationDataHandle init_value; + const xla::Computation* reducer; + if (is_min) { + init_value = XlaHelpers::MaxValue(builder, input_type); + reducer = ctx->GetOrCreateMin(input_type); + } else { + init_value = XlaHelpers::MinValue(builder, input_type); + reducer = ctx->GetOrCreateMax(input_type); + } + + xla::PrimitiveType xla_output_type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); + + xla::ComputationDataHandle input_max = builder->Reduce( + input, init_value, *reducer, /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(input_shape.dims() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + xla::ComputationDataHandle partial_mask = builder->ConvertElementType( + builder->Eq(input, input_max, broadcast_dims), xla_output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; + xla::ComputationDataHandle shift_amount = + XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); + xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic( + builder->ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + xla::ComputationDataHandle iota; + + const int64 axis_size = input_shape.dim_size(axis); + TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); + xla::ComputationDataHandle product = + builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. + xla::ComputationDataHandle output = + builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), + *ctx->GetOrCreateMax(output_type), + /*dimensions_to_reduce=*/{axis}); + *argminmax = output; + return Status::OK(); +} + +} // namespace + xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; @@ -57,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, DataType data_type) { switch (data_type) { + case DT_BFLOAT16: + return b->ConstantR0(bfloat16::epsilon()); case DT_FLOAT: return b->ConstantR0(std::numeric_limits::epsilon()); case DT_DOUBLE: @@ -105,6 +171,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::S16: case xla::U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = *xla::Literal::CreateR0(static_cast(value)); + break; case xla::F16: literal = *xla::Literal::CreateR0(static_cast(value)); @@ -122,25 +191,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, DataType data_type, double value) { - xla::Literal literal; xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - switch (type) { - case xla::F16: - return b->ConstantR0(static_cast(value)); - break; - case xla::F32: - return b->ConstantR0(static_cast(value)); - break; - case xla::F64: - return b->ConstantR0(value); - break; - case xla::C64: - return b->ConstantR0(value); - break; - default: - LOG(FATAL) << "unhandled element type " << type; - } + return ::tensorflow::FloatLiteral(b, type, value); } /* static */ Status XlaHelpers::ReshapeLiteral( @@ -174,6 +227,26 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } +Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder, + XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, + xla::ComputationDataHandle* argmax) { + return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, + axis, /*is_min=*/false, argmax); +} + +Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder, + XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, + xla::ComputationDataHandle* argmin) { + return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, + axis, /*is_min=*/true, argmin); +} + Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, int64 size, xla::ComputationDataHandle* iota) { TensorShape linspace_shape({size}); diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index af23d20fd306c03b5e47c5ca9dd042187a2d51ed..2a027db4c839c917f3a7acd27184792d157356bf 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -72,14 +72,35 @@ class XlaHelpers { gtl::ArraySlice shape, xla::Literal* output); + // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and + // `input_dtype` are the shape and dtype of `input` respectively, and + // `output_type` is the dtype to use for `argmax`. + static Status ArgMax(xla::ComputationBuilder* builder, + XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, + xla::ComputationDataHandle* argmax); + + // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and + // `input_dtype` are the shape and dtype of `input` respectively, and + // `output_type` is the dtype to use for `argmin`. + static Status ArgMin(xla::ComputationBuilder* builder, + XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, + xla::ComputationDataHandle* argmin); + // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. static Status Iota(xla::ComputationBuilder* builder, DataType dtype, int64 size, xla::ComputationDataHandle* iota); // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new - // axis. `indices_shape` is the shape of `indices`. `on_value` and `off_value` - // represent the values to use for the on and off positions, respectively. + // axis. `indices_shape` is the shape of `indices`. `on_value` and + // `off_value` represent the values to use for the on and off positions, + // respectively. static Status OneHot(xla::ComputationBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::ComputationDataHandle& indices, diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1dd454ea8d57e21526e5bcde0c8efc5514983b93..584417bc72c8f6645c05912e857b031cfb394e54 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -37,27 +37,14 @@ namespace { // Returns a vector of positional argument buffer sizes. xla::StatusOr> ComputeArgSizes( - const xla::ProgramShape& program_shape, bool requires_runtime_context) { + const xla::ProgramShape& program_shape) { std::vector arg_sizes; const size_t num_args = program_shape.parameters_size(); arg_sizes.reserve(num_args); for (int i = 0; i < num_args; ++i) { const xla::Shape& arg_shape = program_shape.parameters(i); - if (i == num_args - 1 && requires_runtime_context) { - // If the compiled function needs an XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. - const xla::PrimitiveType type = arg_shape.element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(program_shape)); - } - arg_sizes.push_back(-1); - } else { - constexpr size_t kPointerSize = sizeof(void*); - arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); - } + constexpr size_t kPointerSize = sizeof(void*); + arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); } return std::move(arg_sizes); } @@ -90,21 +77,6 @@ xla::StatusOr ComputeResultIndex( return result_slice.index(); } -// Adapt ComputeFunctionType, which includes a final profile_counters arg, to -// RawFunction, which doesn't include that final arg. -// -// TODO(toddw): Change RawFunction and AOT to also pass the final -// profile_counters arg, and remove this adapter. -XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( - xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { - return [compute_function](void* result, - const xla::ExecutableRunOptions* run_options, - const void** args, void** temps) { - return compute_function(result, run_options, args, temps, - /*profile_counters=*/nullptr); - }; -} - // Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold // the actual strings in nonempty_names, and hold arrays of pointers in // name_ptrs, terminated by a nullptr entry. @@ -144,9 +116,8 @@ XlaJitCompiledCpuFunction::Compile( TF_ASSIGN_OR_RETURN(xla::LocalClient * client, xla::ClientLibrary::GetOrCreateLocalClient()); xla::Computation computation; - bool requires_runtime_context; - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla( - graph_def, config, client, &computation, &requires_runtime_context)); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client, + &computation)); // Get and verify the program shape. TF_ASSIGN_OR_RETURN(std::unique_ptr program_shape, @@ -177,14 +148,13 @@ XlaJitCompiledCpuFunction::Compile( const xla::cpu::CpuExecutable* cpu_executable = static_cast(executable->executable()); XlaCompiledCpuFunction::RawFunction raw_function = - RawFunctionAdapter(cpu_executable->compute_function()); + cpu_executable->compute_function(); const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); // Compute buffer sizes and the result index, needed to run the raw function. - TF_ASSIGN_OR_RETURN( - std::vector arg_sizes, - ComputeArgSizes(*program_shape, requires_runtime_context)); + TF_ASSIGN_OR_RETURN(std::vector arg_sizes, + ComputeArgSizes(*program_shape)); TF_ASSIGN_OR_RETURN(std::vector temp_sizes, ComputeTempSizes(buffer_assignment)); TF_ASSIGN_OR_RETURN(size_t result_index, @@ -203,7 +173,6 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.temp_sizes = jit->temp_sizes_.data(); jit->static_data_.num_temps = jit->temp_sizes_.size(); jit->static_data_.result_index = result_index; - jit->static_data_.requires_runtime_context = requires_runtime_context; // Optional metadata is collected and set below. CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, @@ -211,6 +180,14 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.arg_names = jit->arg_names_.data(); jit->static_data_.result_names = jit->result_names_.data(); jit->static_data_.program_shape = jit->program_shape_.get(); + + if (cpu_executable->hlo_profiling_enabled()) { + jit->static_data_.hlo_profile_printer = + &cpu_executable->hlo_profile_printer(); + jit->static_data_.profile_counters_size = + cpu_executable->hlo_profile_printer().profile_counters_size(); + } + return std::move(jit_unique_ptr); } diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h deleted file mode 100644 index dca420d6ee3fec45f88ac3b450ab0cb4fb83d38a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ - -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -// Forward-declare the ThreadPoolDevice so that it can be ignored unless it's -// actually used. E.g. some ahead-of-time compiled computations don't need a -// thread pool. -namespace Eigen { -struct ThreadPoolDevice; -} - -namespace tensorflow { - -// An instance of this class is passed to each call from tensorflow into a -// compiled XLA computation. See xla_launch_ops.cc. -struct XlaLocalRuntimeContext { - public: - XlaLocalRuntimeContext() {} - - // Kernels implemented using custom call ops set this if they encounter an - // error. The error is checked after the entire XLA computation is - // complete. - // - // error+error_msg are used instead of Status to reduce the binary size - // overhead for ahead-of-time compiled binaries. - bool error = false; - string error_msg; - - // Kernels that need a thread pool can get it from here. - const Eigen::ThreadPoolDevice* thread_pool = nullptr; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalRuntimeContext); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index b948dfee6ab33651e52ca5045cfce600c788bc3b..79d501b511bf37ba4a79ab9d375d6f789a36889b 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -345,6 +345,16 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { expression->set_constant_value(constant); } +void XlaOpKernelContext::SetInvalidOutput(int index) { + Tensor* output = nullptr; + OP_REQUIRES_OK(context_, + context_->allocate_output(index, TensorShape({}), &output)); + XlaExpression* expression = CastExpressionFromUninitializedTensor(output); + xla::ComputationDataHandle handle; + handle.set_handle(0); + expression->set_handle(handle); +} + void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Tensor* output = nullptr; // The shape of the output tensor is the shape of the resource itself @@ -407,6 +417,11 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( return XlaContext::Get(context_).GetOrCreateAdd(type); } +const xla::Computation* XlaOpKernelContext::GetOrCreateMul( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateMul(type); +} + XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 5519e89252ca5a3964dcdaaeb3d08ce6c9da6bd4..f1ae81a5aa9d507a3e0dd577568377385b1844e6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -142,6 +142,10 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); + // Sets output 'index' to an invalid value. + // Any subsequent attempt to consume this output will cause an error. + void SetInvalidOutput(int index); + // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } Status status() { return context_->status(); } @@ -174,7 +178,7 @@ class XlaOpKernelContext { // If this kernel invocation is within a function execution, // call_frame() returns the call frame for the function call. - FunctionCallFrame* call_frame() const { return context_->call_frame(); } + CallFrameInterface* call_frame() const { return context_->call_frame(); } FunctionLibraryRuntime* function_library() const { return context_->function_library(); @@ -206,6 +210,11 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Gets an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + private: OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 02318cf7fa1d4edc12507f6b4d66a8e897cbe100..faf47434b5dc6b569ec4f9c91a8667de275a6315 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -187,22 +188,39 @@ void XlaOpRegistry::RegisterCompilationKernels() { // Constrain each type attribute to the intersection of: // a) the types supported by the backend, and - // b) the attribute's type constraints. - // TODO(phawkins): it may be necessary to also take the intersection with - // the set of types supported by the OpDef. + // b) the types allowed by the OpDef, and + // c) the type constraints. for (const string& type_attr : type_attrs) { KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); attr_constraint->set_name(type_attr); auto* allowed_values = attr_constraint->mutable_allowed_values()->mutable_list(); - auto it = op_registration->type_constraints.find(type_attr); + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? &constraint_it->second + : nullptr; for (DataType dtype : backend.second.supported_types) { - if (it == op_registration->type_constraints.end() || - (it != op_registration->type_constraints.end() && - it->second.find(dtype) != it->second.end())) { - allowed_values->add_type(dtype); + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; + } + // Passed all the filters, this type is allowed. + allowed_values->add_type(dtype); } if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); @@ -245,6 +263,22 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } +std::vector XlaOpRegistry::BackendNames() { + std::vector names; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& backend_pair : registry.backends_) { + names.push_back(backend_pair.first); + } + return names; +} + +bool XlaOpRegistry::IsBackendRegistered(const string& name) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + return registry.backends_.find(name) != registry.backends_.end(); +} + XlaOpRegistry& XlaOpRegistry::Instance() { static XlaOpRegistry* r = new XlaOpRegistry; return *r; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 6aee8c91cc01b4382ef867fa8e438eede008ac73..2959d2ab690dfb91f8f46f5cf5718a405d9e0c7f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -97,6 +97,12 @@ class XlaOpRegistry { gtl::ArraySlice supported_types, BackendOpFilter op_filter); + // Returns the names of the registered backends. + static std::vector BackendNames(); + + // Returns true iff a backend with the given name is registered. + static bool IsBackendRegistered(const string& name); + // Registers `device_name` for XLA compilation, using information from // `registration`. static void RegisterCompilationDevice(const string& device_name, @@ -116,8 +122,8 @@ class XlaOpRegistry { static void RegisterCompilationKernels(); // Returns KernelDefs for compilation ops registered on - // 'compilation_device_name'. - // Does not include kernels registered as CompilationOnly. + // 'compilation_device_name'. Does not include kernels registered as + // CompilationOnly, iff include_compilation_only_kernels=false. static std::vector DeviceKernels( const string& compilation_device_name, bool include_compilation_only_kernels); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 660f419e464936b01a3644e69c2f056f998140f5..d3f292207fee396fb4248dede5c0eeb5cd2b87c9 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -77,6 +77,7 @@ cc_library( hdrs = ["types.h"], visibility = [":friends"], deps = [ + "//tensorflow/core:framework_lite", "//tensorflow/core:lib", "//third_party/eigen3", ], @@ -174,6 +175,7 @@ cc_library( ":types", ":xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core:ptr_util", ], ) @@ -339,6 +341,7 @@ cc_library( name = "array", hdrs = ["array.h"], deps = [ + ":status", ":types", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index ba898d1f4e9100df59c6e4b28824895c5ae6c08a..213e0bac6c77e9972de8d4dd7dfc8c7cf3a1b865 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -23,8 +23,10 @@ limitations under the License. #include #include #include +#include #include +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -35,10 +37,63 @@ limitations under the License. namespace xla { +namespace array_impl { + +// conjunction +// +// Performs a compile-time logical AND operation on the passed types (which +// must have `::value` members convertible to `bool`. Short-circuits if it +// encounters any `false` members (and does not compare the `::value` members +// of any remaining arguments). +// +// This metafunction is designed to be a drop-in replacement for the C++17 +// `std::conjunction` metafunction. +template +struct conjunction; + +template +struct conjunction + : std::conditional, T>::type {}; + +template <> +struct conjunction<> : std::true_type {}; + +// A type trait that is valid when all elements in a parameter pack are of +// integral type. +template +using pack_is_integral = conjunction...>; + +// Compares three same-sized vectors elementwise. For each item in `values`, +// returns false if any of values[i] is outside the half-open range [starts[i], +// ends[i]). +template +bool all_inside_range(const C1& values, const C2& range_starts, + const C3& range_ends) { + for (size_t i = 0, e = values.size(); i < e; ++i) { + if (values[i] < range_starts[i] || values[i] >= range_ends[i]) { + return false; + } + } + return true; +} + +} // namespace array_impl + // General N dimensional array class with arbitrary value type. template class Array { public: + // Type inference can have a hard time parsing very deep initializer list + // nests, especially if one or more dimensions is one as the compiler just + // sees a single-element integer initializer. These typedefs allow casting + // explicitly with less typing. + using InitializerList1D = std::initializer_list; + using InitializerList2D = std::initializer_list; + using InitializerList3D = std::initializer_list; + using InitializerList4D = std::initializer_list; + + using value_type = T; + // Creates a new array with the specified dimensions. explicit Array(tensorflow::gtl::ArraySlice sizes) : Array(sizes, T()) {} @@ -53,7 +108,7 @@ class Array { // Creates a 2D array from the given nested initializer list. The outer // initializer list is the first dimension, the inner is the second dimension. // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. - Array(std::initializer_list> values) + Array(InitializerList2D values) : Array(ToInt64Vector({values.size(), values.begin()->size()})) { int64 idx = 0; for (const auto& it1 : values) { @@ -67,8 +122,7 @@ class Array { // Creates a 3D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. - Array(std::initializer_list>> - values) + Array(InitializerList3D values) : Array(ToInt64Vector({values.size(), values.begin()->size(), values.begin()->begin()->size()})) { int64 idx = 0; @@ -85,9 +139,7 @@ class Array { // Creates a 4D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. - Array(std::initializer_list< - std::initializer_list>>> - values) + Array(InitializerList4D values) : Array(ToInt64Vector({values.size(), values.begin()->size(), values.begin()->begin()->size(), values.begin()->begin()->begin()->size()})) { @@ -173,10 +225,46 @@ class Array { } } + // Invokes a callback with the (indices, value_ptr) for each cell in the + // array. If a callback returns a non-OK status, returns that else returns + // Status::OK(). + Status EachStatus( + std::function, T*)> f) { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + Status s = f(index, &values_[i]); + if (!s.ok()) { + return s; + } + } + return Status::OK(); + } + + // Invokes a callback with the (indices, value) for each cell in the array. + // If a callback returns a non-OK status, returns that else returns + // Status::OK(). + Status EachStatus( + std::function, T)> f) const { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + Status s = f(index, values_[i]); + if (!s.ok()) { + return s; + } + } + return Status::OK(); + } + // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. + // + // The type trait is required to avoid this overload participating too + // eagerly; a parameter pack can take zero or more elements, so we must + // restrict this to only parameter packs that are all of integral type. template - const T& operator()(Dims... dims) const { + typename std::enable_if::value, + const T&>::type + operator()(Dims... dims) const { // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{{static_cast(dims)...}}; @@ -186,7 +274,9 @@ class Array { // Returns the value at the cell specified by the indexes. The number of // arguments have to match with the number of dimensions for the array. template - T& operator()(Dims... dims) { + typename std::enable_if::value, + T&>::type + operator()(Dims... dims) { // We are using a std::array to avoid having to allocate memory in this // function for performance reasons. std::array indexes{{static_cast(dims)...}}; @@ -255,6 +345,59 @@ class Array { bool operator!=(const Array& other) const { return !(*this == other); } + // Performs the equivalent of a slice operation on this array. + Array Slice(tensorflow::gtl::ArraySlice starts, + tensorflow::gtl::ArraySlice limits) const { + CHECK_EQ(starts.size(), num_dimensions()); + CHECK_EQ(limits.size(), num_dimensions()); + + std::vector sizes; + std::transform(starts.begin(), starts.end(), limits.begin(), + std::back_inserter(sizes), + [](int64 start, int64 limit) { return limit - start; }); + Array result(sizes); + + std::vector index(sizes_.size()); + int64 slice_i = 0; + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + if (array_impl::all_inside_range(index, starts, limits)) { + // Even though the bounds of result are different to our bounds, we're + // iterating in the same order. So we can simply write successive linear + // indices instead of recalculating a multi-dimensional index. + result.values_[slice_i++] = values_[i]; + } + } + return result; + } + + // Performs the equivalent of a DynamicUpdateSlice in-place on this array. + void UpdateSlice(const Array& from, + tensorflow::gtl::ArraySlice start_indices) { + CHECK_EQ(from.num_dimensions(), num_dimensions()); + std::vector limit_indices; + std::transform(start_indices.begin(), start_indices.end(), + from.dimensions().begin(), std::back_inserter(limit_indices), + std::plus{}); + std::vector index(sizes_.size()); + int64 from_i = 0; + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + if (array_impl::all_inside_range(index, start_indices, limit_indices)) { + // Even though the bounds of from are different to our bounds, we're + // iterating in the same order. So we can simply write successive linear + // indices instead of recalculating a multi-dimensional index. + values_[i] = from.values_[from_i++]; + } + } + } + + // Performs an in-place reshape, modifying the dimensions but not the + // underlying data. + void Reshape(tensorflow::gtl::ArraySlice new_dimensions) { + int64 old_num_elements = num_elements(); + sizes_ = std::vector(new_dimensions.begin(), new_dimensions.end()); + CHECK_EQ(num_elements(), old_num_elements); + } + // Returns a string representation of the array suitable for debugging. string ToString() const { std::vector pieces; diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index e9449f01ad69a5722f53cce09e2884e20a0def5a..a1c5840a5f3874e27043c821ed4684da2fa6c542 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -36,6 +36,8 @@ namespace xla { template class Array3D : public Array { public: + Array3D() : Array(std::vector{0, 0, 0}) {} + // Creates an array of dimensions n1 x n2 x n3, uninitialized values. Array3D(const int64 n1, const int64 n2, const int64 n3) : Array(std::vector{n1, n2, n3}) {} diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index 093784f541b3bd18f4a1fc1b665cd0d17a892f28..8b9419477479d952126fd831eb44899e7649ca71 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -71,6 +71,19 @@ TEST(ArrayTest, IndexingReadWrite) { EXPECT_EQ(arr(1, 2), 61); } +TEST(ArrayTest, DynamicIndexingReadWrite) { + Array arr({2, 3}); + + std::vector index1 = {1, 1}; + std::vector index2 = {1, 2}; + EXPECT_EQ(arr(index1), 0); + EXPECT_EQ(arr(index2), 0); + arr(index1) = 51; + arr(index2) = 61; + EXPECT_EQ(arr(1, 1), 51); + EXPECT_EQ(arr(1, 2), 61); +} + TEST(ArrayTest, IndexingReadWriteBool) { Array arr{{false, true, false}, {false, true, false}}; @@ -141,5 +154,37 @@ TEST(ArrayTest, Each) { EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum); } +TEST(ArrayTest, Slice) { + Array arr({2, 4}); + arr.FillWithMultiples(1); + + Array identity_slice = arr.Slice({0, 0}, {2, 4}); + EXPECT_EQ(identity_slice.dimensions(), arr.dimensions()); + for (auto it1 = arr.begin(), it2 = identity_slice.begin(), e = arr.end(); + it1 != e; ++it1, ++it2) { + EXPECT_EQ(*it1, *it2); + } + + Array sub_slice = arr.Slice({1, 0}, {2, 2}); + EXPECT_EQ(sub_slice.dimensions(), (std::vector{1, 2})); + const string expected = R"([[4, 5]])"; + EXPECT_EQ(expected, sub_slice.ToString()); +} + +TEST(ArrayTest, UpdateSlice) { + Array arr({3, 4}); + arr.FillWithMultiples(1); + + Array sub_arr({2, 2}); + sub_arr.FillWithMultiples(3); + + arr.UpdateSlice(sub_arr, {1, 1}); + + const string expected = R"([[0, 1, 2, 3], + [4, 0, 3, 7], + [8, 6, 9, 11]])"; + EXPECT_EQ(expected, arr.ToString()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 92cd8e729d659c4ff24c156d89f29275848c3cee..66937d64aff18817bbd5310e0c24e19556e9d727 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -142,8 +142,7 @@ StatusOr> Client::TransferFromOutfeed( "TransferToClient request"); } - Literal literal(response.literal()); - return MakeUnique(literal); + return MakeUnique(response.literal()); } Status Client::ResetDevice() { diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index a716159f9e74041c4823ad20b46fa94c2d7b9d8c..c28380b689c7a0e16bf0bcbf15003f4aa15e42a7 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -67,6 +67,15 @@ class Client { std::vector arguments; ExecutionOptions execution_options; ExecutionProfile* execution_profile; + + ComputationInstance(const Computation& computation, + std::vector arguments, + ExecutionOptions execution_options, + ExecutionProfile* execution_profile) + : computation(computation), + arguments(std::move(arguments)), + execution_options(execution_options), + execution_profile(execution_profile) {} }; // Executes a list ComputationInstances and returns global data produced from @@ -133,7 +142,7 @@ class Client { // Returns a vector of global data handles that point to the tuple elements. StatusOr>> DeconstructTuple( - const GlobalData& computation); + const GlobalData& data); // Retrieves the statistics of the given computation. StatusOr GetComputationStats( diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 24774c4c2a385d9aabd22a550bd8be3acf409d85..317dcb4e41723b93e7e50d911f16e48bc3505a09 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -153,6 +153,7 @@ bool ComputationBuilder::MakeWindow( } else { dim->set_window_dilation(1); } + dim->set_window_reversal(false); } return true; } @@ -624,7 +625,41 @@ ComputationDataHandle ComputationBuilder::Lt( ComputationDataHandle ComputationBuilder::Dot( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{}); + StatusOr> lhs_shape_or_status = GetShape(lhs); + if (!lhs_shape_or_status.ok()) { + NoteError(lhs_shape_or_status.status()); + return ComputationDataHandle(); + } + std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape->dimensions_size() == 1 ? 0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); +} + +ComputationDataHandle ComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + DotRequest request; + *request.mutable_lhs() = lhs; + *request.mutable_rhs() = rhs; + *request.mutable_dimension_numbers() = dimension_numbers; + + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_dot_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making Dot request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); } ComputationDataHandle ComputationBuilder::Conv( @@ -693,11 +728,15 @@ bool ComputationBuilder::VerifyConvolution( } return true; }; - return check_spatial_dimensions("spatial_dimensions", - dimension_numbers.spatial_dimensions()) && + return check_spatial_dimensions( + "input_spatial_dimensions", + dimension_numbers.input_spatial_dimensions()) && check_spatial_dimensions( "kernel_spatial_dimensions", - dimension_numbers.kernel_spatial_dimensions()); + dimension_numbers.kernel_spatial_dimensions()) && + check_spatial_dimensions( + "output_spatial_dimensions", + dimension_numbers.output_spatial_dimensions()); } ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( @@ -729,11 +768,11 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( } std::vector base_area_dimensions( - dimension_numbers.spatial_dimensions_size()); + dimension_numbers.input_spatial_dimensions_size()); for (std::vector::size_type i = 0; i < base_area_dimensions.size(); ++i) { base_area_dimensions[i] = - lhs_shape->dimensions(dimension_numbers.spatial_dimensions(i)); + lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); } std::vector window_dimensions( @@ -1163,6 +1202,34 @@ ComputationDataHandle ComputationBuilder::ConvertElementType( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::BitcastConvertType( + const ComputationDataHandle& operand, PrimitiveType new_element_type) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + StatusOr> shape_status = GetShape(operand); + if (!shape_status.ok()) { + first_error_ = shape_status.status(); + return ComputationDataHandle(); + } + std::unique_ptr original = shape_status.ConsumeValueOrDie(); + + ConvertRequest request; + *request.mutable_operand() = operand; + request.set_new_element_type(new_element_type); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_bitcast_convert_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making bitcast convert request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + ComputationDataHandle ComputationBuilder::SquareF32( const ComputationDataHandle& operand) { return BinaryOp(BINOP_POW, operand, ConstantR0(2.0), @@ -1309,7 +1376,7 @@ Status ComputationBuilder::SetReturnValue( } StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand) { + const ComputationDataHandle& operand, int64 num_parameters) { if (!first_error_.ok()) { return first_error_; } @@ -1317,6 +1384,7 @@ StatusOr ComputationBuilder::IsConstant( IsConstantRequest request; *request.mutable_computation() = computation_.handle(); *request.mutable_operand() = operand; + request.set_num_parameters(num_parameters); IsConstantResponse response; VLOG(2) << "making IsConstant request"; @@ -1330,7 +1398,8 @@ StatusOr ComputationBuilder::IsConstant( } StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout) { + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters) { if (!first_error_.ok()) { return first_error_; } @@ -1341,6 +1410,9 @@ StatusOr> ComputationBuilder::ComputeConstant( if (output_layout != nullptr) { *request.mutable_output_layout() = *output_layout; } + for (const auto& param : parameters) { + *request.add_parameters() = param.ToProto(); + } ComputeConstantResponse response; @@ -1432,6 +1504,34 @@ ComputationDataHandle ComputationBuilder::While( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::Conditional( + const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ConditionalRequest request; + *request.mutable_predicate() = predicate; + *request.mutable_true_operand() = true_operand; + *request.mutable_true_computation() = true_computation.handle(); + *request.mutable_false_operand() = false_operand; + *request.mutable_false_computation() = false_computation.handle(); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_conditional_request() = request; + AddCommonFieldsToOpRequest(&op_request); + OpResponse response; + + VLOG(2) << "making conditional op request"; + Status s = client_->stub()->Op(&op_request, &response); + + return ParseOpResponse(s, &response); +} + ComputationDataHandle ComputationBuilder::Reduce( const ComputationDataHandle& operand, const ComputationDataHandle& init_value, const Computation& computation, @@ -1811,25 +1911,27 @@ ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { dimension_numbers.set_kernel_input_feature_dimension( kConvKernelInputDimension); for (int i = 0; i < num_spatial_dims; ++i) { - dimension_numbers.add_spatial_dimensions(i + 2); + dimension_numbers.add_input_spatial_dimensions(i + 2); dimension_numbers.add_kernel_spatial_dimensions(i + 2); + dimension_numbers.add_output_spatial_dimensions(i + 2); } return dimension_numbers; } /* static */ StatusOr ComputationBuilder::CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 output_batch, - int64 output_feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set( - {input_batch, input_feature, first_spatial, second_spatial}) + if (std::set({input_batch, input_feature, input_first_spatial, + input_second_spatial}) .size() != 4) { return FailedPrecondition( "dimension numbers for the input are not unique: (%lld, %lld, %lld, " "%lld)", - input_batch, input_feature, first_spatial, second_spatial); + input_batch, input_feature, input_first_spatial, input_second_spatial); } if (std::set({kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial}) @@ -1840,25 +1942,28 @@ ComputationBuilder::CreateConvDimensionNumbers( kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial); } - if (std::set( - {output_batch, output_feature, first_spatial, second_spatial}) + if (std::set({output_batch, output_feature, output_first_spatial, + output_second_spatial}) .size() != 4) { return FailedPrecondition( "dimension numbers for the output are not unique: (%lld, %lld, %lld, " "%lld)", - output_batch, output_feature, first_spatial, second_spatial); + output_batch, output_feature, output_first_spatial, + output_second_spatial); } ConvolutionDimensionNumbers dimension_numbers; dimension_numbers.set_input_batch_dimension(input_batch); dimension_numbers.set_input_feature_dimension(input_feature); - dimension_numbers.set_output_batch_dimension(output_batch); - dimension_numbers.set_output_feature_dimension(output_feature); - dimension_numbers.add_spatial_dimensions(first_spatial); - dimension_numbers.add_spatial_dimensions(second_spatial); + dimension_numbers.add_input_spatial_dimensions(input_first_spatial); + dimension_numbers.add_input_spatial_dimensions(input_second_spatial); dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); + dimension_numbers.add_output_spatial_dimensions(output_first_spatial); + dimension_numbers.add_output_spatial_dimensions(output_second_spatial); return dimension_numbers; } diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index d282174947970ab13a8b29ba4212d56ceb0c572a..97531cdc750094adeeb2378d53ebc82cced1cbd8 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -68,6 +68,7 @@ class ShardingBuilder { const TileAssignment& tile_assignment) { OpSharding result; result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + *result.mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment.dimensions()) { result.add_tile_assignment_dimensions(dim); } @@ -120,23 +121,23 @@ class ComputationBuilder { // result, OpMetadata is set on the Computation Builder. All subsequent // instructions generated via this Computation Builder will have the same // OpMetadata attached until a call to ClearOpMetdata. - void SetOpMetadata(const OpMetadata& metadata) { - metadata_ = metadata; - } + void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } // Clears the HloMetadata state. - void ClearOpMetadata() { - metadata_.Clear(); - } + void ClearOpMetadata() { metadata_.Clear(); } - // Sets an OpDeviceAssignment that will be attached to all instructions - // until cleared. + // Sets an OpSharding that will be attached to all instructions until cleared. void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - // Clears the device assignment. Ops will be placed according to the default - // placement policy. + // Clears the sharding. Ops will be sharded according to the default placement + // policy. void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + // Returns the OpSharding that will be attached to all instructions. + const tensorflow::gtl::optional& sharding() const { + return sharding_; + } + // Sets the builder to a mode where it will die immediately when an error is // encountered, rather than producing it in a deferred fashion when Build() is // called (which is the default). @@ -392,6 +393,11 @@ class ComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + // Enqueues a general dot instruction onto the computation. + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + // Default dimension numbers used for a 2D convolution. static constexpr int64 kConvBatchDimension = 0; static constexpr int64 kConvFeatureDimension = 1; @@ -412,8 +418,9 @@ class ComputationBuilder { // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an // error if either the input or the weight dimension numbers have conflicts. static StatusOr CreateConvDimensionNumbers( - int64 input_batch, int64 input_feature, int64 output_batch, - int64 output_feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial); @@ -668,6 +675,13 @@ class ComputationBuilder { ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, PrimitiveType new_element_type); + // Enqueues a no-op instruction onto the computation that changes + // the element type of the operand array to primitive_type. The + // bit-widths of the source and destination element types must be + // identical. + ComputationDataHandle BitcastConvertType(const ComputationDataHandle& operand, + PrimitiveType new_element_type); + // Enqueues a float32 reciprocal instruction onto the computation. // (float32 is specified as there is an implicit float32 -1.0f constant // exponent). @@ -727,6 +741,13 @@ class ComputationBuilder { const Computation& body, const ComputationDataHandle& init); + // Enqueues a conditional node onto the computation. + ComputationDataHandle Conditional(const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation); + // Enqueues a ReducePrecision node onto the computation. ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, const int exponent_bits, @@ -742,11 +763,12 @@ class ComputationBuilder { ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. Unlike `ComputeConstant`, `IsConstant` tests - // whether a computation is a compile-time constant without evaluating the - // computation. - StatusOr IsConstant(const ComputationDataHandle& operand); + // constant does not depend on parameters with higher index then + // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. + // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a + // compile-time constant without evaluating the computation. + StatusOr IsConstant(const ComputationDataHandle& operand, + int64 num_parameters = 0); // Normalizes operand across spatial and batch dimensions for each feature. // @@ -791,7 +813,7 @@ class ComputationBuilder { float epsilon, int64 feature_index); // Computes the value of a constant indicated by a - // ComputationDataHandle. + // ComputationDataHandle using a non-optimized interpreter on the host. // // The operand must be from the computation currently being built - // i.e., returned from this builder with no intervening call to @@ -799,8 +821,11 @@ class ComputationBuilder { // that may stop working at any time. // // The operand must represent a constant value, which in this case - // means that it must not statically depend on a parameter to the - // computation that is being built. + // means that it must not statically depend on any parameter of the + // computation that is being built other then the ones specified on the + // paramtere list. The parameters in the list will be indexed by their + // parameter id property so the number of parameters specified should be at + // least as many as the largest used parameter index. // // `IsConstant` can be used to test whether a computation is a compile-time // constant without evaluation it. `ComputeConstant` only succeeds for @@ -818,7 +843,8 @@ class ComputationBuilder { // will be stored using that layout. StatusOr> ComputeConstant( const ComputationDataHandle& operand, - const Layout* output_layout = nullptr); + const Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}); // Returns a new ComputationBuilder whose resultant Computation is used only // by this ComputationBuilder. The sub-ComputationBuilder has the same @@ -1038,6 +1064,33 @@ ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( return ConstantFromArray(values); } +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +class ScopedShardingAssignment { + public: + ScopedShardingAssignment(xla::ComputationBuilder* builder, + tensorflow::gtl::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + ~ScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const tensorflow::gtl::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::ComputationBuilder* const builder_; + tensorflow::gtl::optional prev_sharding_; + + TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment); +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index ee3468208792879c3fe4ff5860e434ef5a0c0155..fca2bf2688cd21b44f099da3bae3b890cbb069ab 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -44,6 +44,7 @@ cc_library( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index e6645e4941bd04c658b67117bb689f6fdef7dfc1..5f2b55713e342aa3d0251386d57cb52481fe748d 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -48,65 +49,9 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, } // namespace -StatusOr> MakeFakeLiteral(const Shape& shape) { - if (ShapeUtil::IsTuple(shape)) { - std::vector> elements; - for (const Shape& element_shape : shape.tuple_shapes()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr element, - MakeFakeLiteral(element_shape)); - elements.push_back(std::move(element)); - } - return Literal::MakeTupleOwned(std::move(elements)); - } - std::unique_ptr literal = Literal::CreateFromShape(shape); - std::minstd_rand0 engine; - switch (shape.element_type()) { - case F32: { - std::uniform_real_distribution generator(0.0f, 1.0f); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case S32: { - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), - std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case S64: { - std::uniform_int_distribution generator( - std::numeric_limits::lowest(), - std::numeric_limits::max()); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - case PRED: { - std::uniform_int_distribution generator(0, 1); - TF_CHECK_OK(literal->Populate( - [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); - })); - break; - } - default: - return Unimplemented("Unsupported type for fake literal generation: %s", - ShapeUtil::HumanString(shape).c_str()); - } - return std::move(literal); -} - std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { - if (ShapeUtil::ByteSizeOf(shape) < (1LL << 30)) { + if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) { StatusOr> literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h index b5c4393dcc3e37c03a5b0e1a806b0f8b07a132ed..7e640d1307edcc3e2c021f4391c456f578a015ee 100644 --- a/tensorflow/compiler/xla/client/lib/testing.h +++ b/tensorflow/compiler/xla/client/lib/testing.h @@ -26,10 +26,6 @@ limitations under the License. namespace xla { -// Generates fake data in a literal of the given shape, or returns an error -// status if the element type is currently unhandled for fake data generation. -StatusOr> MakeFakeLiteral(const Shape& shape); - // Generates fake data of the given shape on the device or dies. The fake data // is created by performing a computation on the device rather than transferring // data from the host to the device. diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 15c744ecd349e91dc703bec5708d78a896f132c3..b051955f0fd85b7ca886bc0238068aeb94427209 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -27,16 +27,6 @@ namespace se = ::perftools::gputools; namespace xla { -ExecutableBuildOptions& ExecutableBuildOptions::set_platform( - perftools::gputools::Platform* platform) { - platform_ = platform; - return *this; -} - -perftools::gputools::Platform* ExecutableBuildOptions::platform() const { - return platform_; -} - ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( int device_ordinal) { device_ordinal_ = device_ordinal; @@ -56,16 +46,6 @@ const Shape* ExecutableBuildOptions::result_layout() const { return result_layout_set_ ? &result_layout_ : nullptr; } -ExecutableBuildOptions& ExecutableBuildOptions::set_has_hybrid_result( - bool has_hybrid_result) { - has_hybrid_result_ = has_hybrid_result; - return *this; -} - -bool ExecutableBuildOptions::has_hybrid_result() const { - return has_hybrid_result_; -} - namespace { StatusOr BorrowStreamForDevice(int device_ordinal, Backend* backend) { @@ -230,9 +210,9 @@ tensorflow::Status LocalExecutable::RecordArguments( SessionModule* session_module) { session_module->clear_arguments(); for (const ShapedBuffer* argument : arguments) { - Literal literal; - TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*argument, &literal)); - *session_module->add_arguments() = literal.ToProto(); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + LiteralFromShapedBuffer(*argument)); + *session_module->add_arguments() = literal->ToProto(); } return Status::OK(); } @@ -240,21 +220,19 @@ tensorflow::Status LocalExecutable::RecordArguments( tensorflow::Status LocalExecutable::RecordResult( const ShapedBuffer* result, SessionModule* session_module) { session_module->clear_result(); - Literal literal(session_module->result()); - TF_RETURN_IF_ERROR(LiteralFromShapedBuffer(*result, &literal)); - *session_module->mutable_result() = literal.ToProto(); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + LiteralFromShapedBuffer(*result)); + *session_module->mutable_result() = literal->ToProto(); return Status::OK(); } -// TODO(dnovillo) Change signature to return StatusOr. -tensorflow::Status LocalExecutable::LiteralFromShapedBuffer( - const ShapedBuffer& shaped_buffer, Literal* literal) { +StatusOr> LocalExecutable::LiteralFromShapedBuffer( + const ShapedBuffer& shaped_buffer) { TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, backend_->stream_executor(shaped_buffer.device_ordinal())); - return backend_->transfer_manager()->TransferLiteralFromDevice( - executor, shaped_buffer.buffer({}), shaped_buffer.shape(), - shaped_buffer.shape(), literal); + return backend_->transfer_manager()->TransferLiteralFromDevice(executor, + shaped_buffer); } se::Platform* LocalClient::platform() const { @@ -297,9 +275,6 @@ StatusOr> LocalClient::Compile( device_ordinal, options)); } -// Copy the literal data to the device with the given ordinal and return as a -// ScopedShapedBuffer. The given memory allocator is used for device memory -// allocation. StatusOr> LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, DeviceMemoryAllocator* allocator) { @@ -308,46 +283,42 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, } TF_ASSIGN_OR_RETURN( auto scoped_buffer, - ScopedShapedBuffer::Allocate(literal.shape(), allocator, device_ordinal)); + ScopedShapedBuffer::Allocate( + literal.shape(), allocator, device_ordinal, + [this](const Shape& shape) { + return backend().transfer_manager()->GetByteSizeRequirement(shape); + })); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - literal.shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - // This is a leaf of the shape. Transfer the literal array data to the - // device buffer. - return backend().transfer_manager()->TransferLiteralToDevice( - executor, literal.GetSubliteral(index), - scoped_buffer->mutable_buffer(index)); - } - return Status::OK(); - })); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + executor, literal, *scoped_buffer)); return std::move(scoped_buffer); } -// Copy the data from the device contained in the given ShapedBuffer and -// return as a Literal. StatusOr> LocalClient::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - std::unique_ptr literal = - Literal::CreateFromShape(shaped_buffer.shape()); TF_ASSIGN_OR_RETURN( se::StreamExecutor * executor, backend().stream_executor(shaped_buffer.device_ordinal())); - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - literal->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - if (ShapeUtil::IsArray(subshape)) { - // This is a leaf of the shape. Transfer the device buffer into the - // literal. The layout of the literal and the device buffer are - // necessarily the same so we pass 'subshape' for both device and - // literal shapes. - return backend().transfer_manager()->TransferLiteralFromDevice( - executor, shaped_buffer.buffer(index), - /*device_shape=*/subshape, - /*literal_shape*/ subshape, &literal->GetSubliteral(index)); - } - return Status::OK(); - })); + return backend().transfer_manager()->TransferLiteralFromDevice(executor, + shaped_buffer); +} + +Status LocalClient::TransferToInfeedLocal(const Literal& literal, + int device_ordinal) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device_ordinal)); + return backend().transfer_manager()->TransferLiteralToInfeed(executor, + literal); +} + +StatusOr> LocalClient::TransferFromOutfeedLocal( + const Shape& shape, int device_ordinal) { + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + backend().stream_executor(device_ordinal)); + auto literal = MakeUnique(); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( + executor, shape, literal.get())); return std::move(literal); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 9f985ed5275815de2d59f6caedbbcc8060420a13..3ca0d2ef5513cfb6b0dbfbc63b311f81a318356e 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -37,14 +37,6 @@ namespace xla { // LocalClient::Compile. class ExecutableBuildOptions { public: - // If set, this is the platform to build the computation for. This must match - // the underlying platform of the service. A value of nullptr indicates the - // option has not been set. - // - // TODO(b/28616830): Support multiple platforms. - ExecutableBuildOptions& set_platform(perftools::gputools::Platform* platform); - perftools::gputools::Platform* platform() const; - // If set, this is the device to build the computation for. Valid // device_ordinal values are: 0 to # of devices - 1. These values are // identical to the device ordinal values used by StreamExecutor. The built @@ -61,18 +53,10 @@ class ExecutableBuildOptions { ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); const Shape* result_layout() const; - // If set, the executable will be built to output a hybrid - // ShapedBuffer with top-level tuple pointers in host memory and - // result buffers in device memory. - ExecutableBuildOptions& set_has_hybrid_result(bool has_hybrid_result); - bool has_hybrid_result() const; - private: - perftools::gputools::Platform* platform_ = nullptr; int device_ordinal_ = -1; Shape result_layout_; bool result_layout_set_ = false; - bool has_hybrid_result_ = true; }; class LocalExecutable { @@ -129,9 +113,9 @@ class LocalExecutable { tensorflow::Status RecordResult(const ShapedBuffer* result, SessionModule* session_module); - // Copies the contents of a ShapedBuffer into a Literal proto. - tensorflow::Status LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer, - Literal* literal); + // Returns a literal containing the contents of the given ShapedBuffer. + StatusOr> LiteralFromShapedBuffer( + const ShapedBuffer& shaped_buffer); // Compiled computation. std::unique_ptr executable_; @@ -178,6 +162,20 @@ class LocalClient : public Client { StatusOr> ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); + // Transfer the given literal to the infeed queue of the given device. + // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does + // not inherit from Client and there is no possibility of confusion with + // Client::TransferToInfeed. + Status TransferToInfeedLocal(const Literal& literal, int device_ordinal); + + // Transfer and return a value of the given shape from the outfeed of the + // given device. + // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does + // not inherit from Client and there is no possibility of confusion with + // Client::TransferFromOutfeed. + StatusOr> TransferFromOutfeedLocal( + const Shape& shape, int device_ordinal); + // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index f2cdd9669c727bb778fce495ede0faaf2d9a923d..bfafef0a40f55e13ac94b2d1750df25146081784 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -31,7 +31,6 @@ std::vector* flag_objects; std::once_flag flags_init; void SetDebugOptionsDefaults(DebugOptions* flags) { - flags->set_xla_hlo_graph_path("/tmp/"); flags->set_xla_enable_fast_math(true); flags->set_xla_llvm_enable_alias_scope_metadata(true); flags->set_xla_llvm_enable_noalias_metadata(true); @@ -117,9 +116,22 @@ void AllocateFlags() { bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), flag_values->xla_hlo_dump_as_graphdef(), "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the " + "HLO graphs."), + tensorflow::Flag( + "xla_hlo_tfgraph_device_scopes", + bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), + flag_values->xla_hlo_tfgraph_device_scopes(), + "When generating TensorFlow HLO graphs, if the HLO instructions " + "are assigned to a specific device, prefix the name scope with " + "\"devX\" with X being the device ordinal."), tensorflow::Flag( "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), - "HLO modules matching this regex will be dumped to LOG(INFO). "), + "HLO modules matching this regex will be dumped to LOG(INFO)."), tensorflow::Flag( "xla_generate_hlo_text_to", flag_values->mutable_xla_generate_hlo_text_to(), diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 8fc8644a60ef62d7ba5e7f0cc11253742395f09b..42c9d21149a41a3d60f2cfff65d3af08d7c8b9d7 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -33,6 +33,20 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +namespace { +using tensorflow::int64; + +constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; + +// Converts between little and big endian, assuming elements in the array are 16 +// bits long. +void ConvertEndianShort(char* bytes, int64 size) { + CHECK_EQ(size / 2, 0); + for (int64 i = 0; i < size; i += 2) { + std::swap(bytes[i], bytes[i + 1]); + } +} +} // namespace namespace xla { @@ -169,6 +183,8 @@ Status Literal::Copy(const Literal& src_literal, return CopyRange(src_literal, src_base, dest_base, copy_size); case F16: return CopyRange(src_literal, src_base, dest_base, copy_size); + case BF16: + return CopyRange(src_literal, src_base, dest_base, copy_size); case F32: return CopyRange(src_literal, src_base, dest_base, copy_size); case F64: @@ -200,6 +216,8 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(0); case F16: return *Literal::CreateR0(static_cast(0.0f)); + case BF16: + return *Literal::CreateR0(static_cast(0.0f)); case F32: return *Literal::CreateR0(0); case F64: @@ -234,6 +252,10 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(1); case S64: return *Literal::CreateR0(1); + case F16: + return *Literal::CreateR0(static_cast(1.0f)); + case BF16: + return *Literal::CreateR0(static_cast(1.0f)); case F32: return *Literal::CreateR0(1); case F64: @@ -245,8 +267,6 @@ Status Literal::Copy(const Literal& src_literal, case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - return *Literal::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take on value of 1"; case OPAQUE: @@ -285,6 +305,9 @@ Status Literal::Copy(const Literal& src_literal, case F16: return *Literal::CreateR0( static_cast(-std::numeric_limits::infinity())); + case BF16: + return *Literal::CreateR0( + static_cast(-std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: @@ -321,6 +344,9 @@ Status Literal::Copy(const Literal& src_literal, case F16: return *Literal::CreateR0( static_cast(std::numeric_limits::infinity())); + case BF16: + return *Literal::CreateR0( + static_cast(std::numeric_limits::infinity())); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -428,6 +454,7 @@ std::unique_ptr Literal::Transpose( // The shape with affine layout resulting from that operation will be // F32[8,11]{0,1}, since it leaves the original most minor (the 8 sized), the // most minor. + // // Essentially, given MinMaj(Di) the position of the Di dimension within the // minor to major vector, and given T(Di) the index that the original Di // dimension has within the transposed array, a layout is affine if @@ -536,6 +563,9 @@ string Literal::GetAsString( } case F16: return tensorflow::strings::StrCat(Get(multi_index)); + case BF16: + return tensorflow::strings::StrCat( + static_cast(Get(multi_index))); default: return tensorflow::strings::StrCat( "[", PrimitiveType_Name(shape().element_type()), "]"); @@ -569,9 +599,17 @@ int64 Literal::LinearIndex( return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); } -string Literal::ToString() const { +string Literal::ToString(bool print_layout) const { std::vector pieces; + auto shape_to_string = [print_layout](const Shape& shape) { + if (print_layout) { + return ShapeUtil::HumanStringWithLayout(shape); + } else { + return ShapeUtil::HumanString(shape); + } + }; + auto element_to_string = [this](tensorflow::gtl::ArraySlice indices) -> string { PrimitiveType element_type = shape().element_type(); @@ -585,13 +623,13 @@ string Literal::ToString() const { // TODO(b/32894291): refactor this code to reduce code duplication. if (ShapeUtil::IsTuple(shape())) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" (\n"); - for (const auto& element_literal : tuple_literals()) { - pieces.push_back(element_literal.ToString()); - pieces.push_back(",\n"); - } - pieces.push_back(")"); + pieces.push_back(tensorflow::str_util::Join( + tuple_literals(), ",\n", [](string* out, const Literal& element) { + tensorflow::strings::StrAppend(out, element.ToString()); + })); + pieces.push_back("\n)"); } else if (ShapeUtil::Rank(shape()) == 0) { pieces.push_back(GetAsString({})); } else if (ShapeUtil::Rank(shape()) == 1) { @@ -601,7 +639,7 @@ string Literal::ToString() const { } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 2) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(" { "); @@ -609,11 +647,11 @@ string Literal::ToString() const { pieces.push_back(element_to_string({i0, i1})); } pieces.push_back(" "); - pieces.push_back("},\n"); + pieces.push_back(i0 == shape().dimensions(0) - 1 ? "}\n" : "},\n"); } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 3) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { pieces.push_back(i0 > 0 ? ",\n{" : "{"); @@ -628,53 +666,62 @@ string Literal::ToString() const { } pieces.push_back("\n}"); } else if (ShapeUtil::Rank(shape()) == 4) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); + pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( - tensorflow::strings::Printf(" { // i1=%lld\n", i1)); + tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1)); for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back(" {"); for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(element_to_string({i0, i1, i2, i3})); } - pieces.push_back("},\n"); + pieces.push_back(i2 == shape().dimensions(2) - 1 ? "}\n" : "},\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n"); } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 5) { - pieces.push_back(ShapeUtil::HumanString(shape())); + pieces.push_back(shape_to_string(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); + pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( - tensorflow::strings::Printf(" { // i1=%lld\n", i1)); + tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1)); for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back( - tensorflow::strings::Printf(" { // i2=%lld\n", i2)); + tensorflow::strings::Printf(" { /*i2=%lld*/\n", i2)); for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(" {"); for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) { pieces.push_back(element_to_string({i0, i1, i2, i3, i4})); } - pieces.push_back("},\n"); + pieces.push_back(i3 == shape().dimensions(3) - 1 ? "}\n" : "},\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i2 == shape().dimensions(2) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n"); } pieces.push_back("}"); } else { - pieces.push_back(ShapeUtil::HumanString(shape())); - pieces.push_back(" {...}"); + pieces.push_back(shape_to_string(shape())); + pieces.push_back(" {"); + EachCellAsString( + [&](tensorflow::gtl::ArraySlice indices, const string& value) { + pieces.push_back(" "); + pieces.push_back(value); + }); + pieces.push_back("}"); } return tensorflow::str_util::Join(pieces, ""); @@ -732,6 +779,8 @@ void* Literal::MutableInternalData() { return reinterpret_cast(c64s_.data()); case F16: return reinterpret_cast(f16s_.data()); + case BF16: + return reinterpret_cast(bf16s_.data()); default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(shape().element_type()); @@ -774,6 +823,9 @@ void Literal::Reserve(int64 num_elements) { case F16: Resize(num_elements, static_cast(0.0f)); break; + case BF16: + Resize(num_elements, static_cast(0.0f)); + break; default: LOG(FATAL) << "primitive type not supported in literals: " << PrimitiveType_Name(shape().element_type()); @@ -813,6 +865,9 @@ tensorflow::Status Literal::ValidateLiteral() const { case F16: actual = f16s().size() / sizeof(half); break; + case BF16: + actual = bf16s().size(); + break; default: return tensorflow::errors::Unimplemented( "unhandled element type for literal validation: " + @@ -909,6 +964,7 @@ StatusOr> ConvertIfDestTypeMatches( CONVERT_IF_TYPES_MATCH(F16) CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F64) + CONVERT_IF_TYPES_MATCH(BF16) #undef CONVERT_IF_TYPES_MATCH case C64: return ConvertToC64(src_literal); @@ -938,8 +994,9 @@ StatusOr> Literal::Convert( CONVERT_IF_DEST_TYPE_MATCHES(F16) CONVERT_IF_DEST_TYPE_MATCHES(F32) CONVERT_IF_DEST_TYPE_MATCHES(F64) + CONVERT_IF_DEST_TYPE_MATCHES(BF16) #undef CONVERT_IF_DEST_TYPE_MATCHES - // Other types are not yet supported. + // Other types are not yet supported. default: return InvalidArgument("Unimplemented: Convert from type %s to type %s", PrimitiveType_Name(shape().element_type()).c_str(), @@ -1008,6 +1065,8 @@ bool Literal::operator==(const Literal& other) const { return EqualElements(*this, other, 0, &multi_index); case F16: return EqualElements(*this, other, 0, &multi_index); + case BF16: + return EqualElements(*this, other, 0, &multi_index); case C64: return EqualElements(*this, other, 0, &multi_index); default: @@ -1117,13 +1176,18 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - // TODO - there is an endianess problem here. fix it, or wait for uint16 - // support in protobuf auto values = mutable_f16s(); return tensorflow::gtl::MutableArraySlice(values->data(), values->size()); } +template <> +tensorflow::gtl::MutableArraySlice +Literal::GetMutableArraySlice() { + auto values = mutable_bf16s(); + return {values->data(), values->size()}; +} + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { CHECK_EQ(shape().element_type(), PRED); @@ -1194,6 +1258,12 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { f16s().size() / sizeof(half)); } +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { + CHECK_EQ(shape().element_type(), BF16); + return {bf16s().data(), bf16s().size()}; +} + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { @@ -1242,6 +1312,9 @@ bool Literal::IsAll(int8 value) const { return AllElementsEqualValue(*this, value); case F16: return AllElementsEqualValue(*this, static_cast(value)); + case BF16: + return AllElementsEqualValue(*this, + static_cast(value)); case PRED: if (value == 0) { return AllElementsEqualValue(*this, false); @@ -1263,6 +1336,9 @@ bool Literal::IsAllFloat(float value) const { return AllElementsEqualValue(*this, value); case F16: return AllElementsEqualValue(*this, static_cast(value)); + case BF16: + return AllElementsEqualValue(*this, + static_cast(value)); default: return false; } @@ -1299,6 +1375,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { return Get(indices) == complex64(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); + case BF16: + return Get(indices) == static_cast(0.0f); case PRED: return Get(indices) == false; default: @@ -1366,6 +1444,12 @@ void Literal::Resize(int64 num_elements, half value) { mutable_f16s()->resize(num_elements, value); } +template <> +void Literal::Resize(int64 num_elements, bfloat16 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_bf16s()->resize(num_elements, value); +} + template <> void Literal::Resize(int64 num_elements, complex64 value) { CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); @@ -1414,6 +1498,19 @@ LiteralProto Literal::ToProto() const { *proto.mutable_f16s() = string(reinterpret_cast(f16s_.data()), f16s_.size() * sizeof(half)); + if (!kLittleEndian) { + ConvertEndianShort(const_cast(proto.mutable_f16s()->data()), + proto.f16s().size()); + } + break; + case BF16: + *proto.mutable_bf16s() = + string(reinterpret_cast(bf16s_.data()), + bf16s_.size() * sizeof(bfloat16)); + if (!kLittleEndian) { + ConvertEndianShort(const_cast(proto.mutable_bf16s()->data()), + proto.bf16s().size()); + } break; case F32: CopyToRepeatedField(proto.mutable_f32s(), f32s()); @@ -1482,6 +1579,21 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { CHECK_EQ(0, s.size() % sizeof(half)); f16s_ = std::vector(s.size() / sizeof(half)); memcpy(f16s_.data(), s.data(), s.size()); + + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(f16s_.data()), s.size()); + } + break; + } + case BF16: { + const string& s(literal_proto.bf16s()); + CHECK_EQ(0, s.size() % sizeof(bfloat16)); + bf16s_ = std::vector(s.size() / sizeof(bfloat16)); + memcpy(bf16s_.data(), s.data(), s.size()); + + if (!kLittleEndian) { + ConvertEndianShort(reinterpret_cast(bf16s_.data()), s.size()); + } break; } case F32: diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index a1e288829f22835f94c6e3c041796f84d995211c..2981f9f8753a60f7acb7e3c6bf86f2b9da4c96d8 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -99,6 +99,7 @@ class Literal { f16s_.clear(); f32s_.clear(); f64s_.clear(); + c64s_.clear(); tuple_literals_.clear(); } @@ -163,6 +164,11 @@ class Literal { const std::vector& c64s() const { return c64s_; } std::vector* mutable_c64s() { return &c64s_; } + int bf16s_size() const { return bf16s().size(); } + bfloat16 bf16s(int i) const { return bf16s_[i]; } + const std::vector& bf16s() const { return bf16s_; } + std::vector* mutable_bf16s() { return &bf16s_; } + int tuple_literals_size() const { return tuple_literals().size(); } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } Literal* add_tuple_literals() { @@ -280,11 +286,11 @@ class Literal { std::unique_ptr Relayout(const Layout& new_layout, const ShapeIndex& shape_index = {}) const; - // Creates a new literal by reshaping this literal to have 'shape'. Both the - // original shape and 'shape' must contain the same number of elements. The + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. StatusOr> Reshape( - tensorflow::gtl::ArraySlice shape) const; + tensorflow::gtl::ArraySlice dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -450,7 +456,7 @@ class Literal { tensorflow::Status ValidateLiteral() const; // Returns a string representation of the literal value. - string ToString() const; + string ToString(bool print_layout = false) const; // Invokes the "per cell" callback for each element in the provided // literal with the element's indices and a string representation of @@ -622,6 +628,7 @@ class Literal { std::vector u16s_; std::vector u32s_; std::vector u64s_; + std::vector bf16s_; std::vector f16s_; std::vector f32s_; std::vector f64s_; @@ -674,6 +681,9 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; @@ -714,6 +724,9 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); @@ -747,6 +760,9 @@ void Literal::Resize(int64 num_elements, double value); template <> void Literal::Resize(int64 num_elements, half value); +template <> +void Literal::Resize(int64 num_elements, bfloat16 value); + template <> void Literal::Resize(int64 num_elements, complex64 value); @@ -990,6 +1006,14 @@ inline half Literal::Get( return GetArraySlice()[linear_index]; } +template <> +inline bfloat16 Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(shape().element_type() == BF16); + int64 linear_index = LinearIndex(multi_index); + return GetArraySlice()[linear_index]; +} + template void Literal::Set(tensorflow::gtl::ArraySlice multi_index, NativeT value) { diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index a9af4849e2124fd47ae42cc06ac8cc5ca5a22cb7..7ff64c4134155e7fe22ab99584970a7d6d6e8803 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -110,6 +110,18 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); + + auto bf16_lit = Literal::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", bf16_lit->ToString()); + + // 3.14 will be truncated to 3.125 in bfloat16 format. + auto bf16_lit_truncated = + Literal::CreateR0(static_cast(3.14f)); + ASSERT_EQ("3.125", bf16_lit_truncated->ToString()); + + auto bf16_lit_truncated2 = + Literal::CreateR0(static_cast(9.001f)); + ASSERT_EQ("9", bf16_lit_truncated2->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -122,7 +134,7 @@ TEST_F(LiteralUtilTest, R2ToString) { const string expected = R"(s32[3,2] { { 1, 2 }, { 3, 4 }, - { 5, 6 }, + { 5, 6 } })"; ASSERT_EQ(expected, literal->ToString()); } @@ -148,8 +160,8 @@ TEST_F(LiteralUtilTest, TupleToString) { 1, f32[2,2] { { 1, 2 }, - { 3, 4 }, -}, + { 3, 4 } +} ))"; ASSERT_EQ(expected, tuple->ToString()); } @@ -191,18 +203,18 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); string result = literal->ToString(); const string expected = R"(f32[1,2,3,2] { - { // i0=0 - { // i1=0 + { /*i0=0*/ + { /*i1=0*/ {1, 2}, {1001, 1002}, - {2001, 2002}, + {2001, 2002} }, - { // i1=1 + { /*i1=1*/ {1, 2}, {1001, 1002}, - {2001, 2002}, - }, - }, + {2001, 2002} + } + } })"; ASSERT_EQ(expected, result); } @@ -212,30 +224,30 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { ElementsAre(2, 2, 3, 3)); string result = literal_r4_2x2x3x3_dim0major_->ToString(); const string expected = R"(f32[2,2,3,3] { - { // i0=0 - { // i1=0 + { /*i0=0*/ + { /*i1=0*/ {1, 2, 3}, {4, 5, 6}, - {7, 8, 9}, + {7, 8, 9} }, - { // i1=1 + { /*i1=1*/ {11, 12, 13}, {14, 15, 16}, - {17, 18, 19}, - }, + {17, 18, 19} + } }, - { // i0=1 - { // i1=0 + { /*i0=1*/ + { /*i1=0*/ {101, 102, 103}, {104, 105, 106}, - {107, 108, 109}, + {107, 108, 109} }, - { // i1=1 + { /*i1=1*/ {201, 202, 203}, {204, 205, 206}, - {207, 208, 209}, - }, - }, + {207, 208, 209} + } + } })"; ASSERT_EQ(expected, result); } @@ -397,6 +409,18 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + bfloat16 b8(8.0f); + bfloat16 b9(9.0f); + + EXPECT_TRUE(Literal::CreateR2({{b8}, {b8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{b8}, {b9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{b9}, {b8}})->IsAll(8)); + + // 9.001 will be truncated to 9.0 + bfloat16 b91(9.001f); + bfloat16 b90(9.00f); + EXPECT_TRUE(Literal::CreateR2({{b91}, {b90}})->IsAll(9.0)); + complex64 c8_9 = {8, 9}; EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); @@ -491,7 +515,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { TEST_F(LiteralUtilTest, ReshapeR0) { auto original = Literal::CreateR0(1.7f); - auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); + auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); EXPECT_EQ(*original, *reshape); } @@ -691,6 +715,30 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { + Literal output; + bfloat16 h(0.25f); + output.PopulateWithValue(h, {}); + auto expected = Literal::CreateR0(h); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { + Literal output; + bfloat16 h(0.5f); + output.PopulateWithValue(h, {3}); + auto expected = Literal::CreateR1({h, h, h}); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { + Literal output; + bfloat16 h(2.0f); + output.PopulateWithValue(h, {2, 2}); + auto expected = Literal::CreateR2({{h, h}, {h, h}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue(2.5f, {}); @@ -975,6 +1023,14 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{half(26.0), half(0.0), half(28.0), half(0.0)}, {half(0.0), half(31.0), half(0.0), half(33.0)}}, }}, layout_r4_dim0major_); + auto bf16 = Literal::CreateR4WithLayout({{ + {{bfloat16(10.0), bfloat16(0.0), bfloat16(12.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(15.0), bfloat16(0.0), bfloat16(17.0)}}, + {{bfloat16(0.0), bfloat16(19.0), bfloat16(0.0), bfloat16(21.0)}, + {bfloat16(22.0), bfloat16(0.0), bfloat16(24.0), bfloat16(0.0)}}, + {{bfloat16(26.0), bfloat16(0.0), bfloat16(28.0), bfloat16(0.0)}, + {bfloat16(0.0), bfloat16(31.0), bfloat16(0.0), bfloat16(33.0)}}, + }}, layout_r4_dim0major_); auto f32 = Literal::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, @@ -1008,6 +1064,12 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = s8->Convert(PRED).ConsumeValueOrDie(); EXPECT_EQ(*conv, *pred); + conv = bf16->Convert(S32).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *s32); + + conv = bf16->Convert(F32).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *f32); + conv = pred->Convert(S32).ConsumeValueOrDie(); EXPECT_EQ(*conv, *int32_pred); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2113b5e06f3eb0169be50c0ee731a903c0eece9d..2bce56b7bd2f91f20ea670d0e7ccaa432c2b5f9f 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -78,6 +78,11 @@ PrimitiveType NativeToPrimitiveType() { return F64; } +template <> +PrimitiveType NativeToPrimitiveType() { + return BF16; +} + template <> PrimitiveType NativeToPrimitiveType() { return F16; @@ -89,7 +94,7 @@ PrimitiveType NativeToPrimitiveType() { } bool IsFloatingPointType(PrimitiveType type) { - return type == F16 || type == F32 || type == F64; + return type == F16 || type == F32 || type == F64 || type == BF16; } bool IsComplexType(PrimitiveType type) { return type == C64; } @@ -118,6 +123,7 @@ int BitWidth(PrimitiveType type) { case S16: case U16: case F16: + case BF16: return 16; case U32: diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index a49c8b86fcfe156ea3733ce05c0fb7337cf60dce..cb4583d198b454be1432134a9f6a77dbbbe5bdd8 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -26,6 +26,13 @@ limitations under the License. namespace xla { namespace primitive_util { +// The number of exponent bits in a BF16 value. +const int kBFloat16ExponentBits = 8; + +// The number of mantissa bits in a BF16 value. There is an implicit leading +// 1, so there is an implicit additional bit of precision. +const int kBFloat16MantissaBits = 7; + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). template @@ -77,6 +84,8 @@ template <> PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +template <> +PrimitiveType NativeToPrimitiveType(); // Complex template <> @@ -167,6 +176,11 @@ struct PrimitiveTypeToNative { using type = half; }; +template <> +struct PrimitiveTypeToNative { + using type = bfloat16; +}; + // Complex template <> struct PrimitiveTypeToNative { diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h index fa670303136ebff0c3e0e32f5c64e879c46fe964..c58c19db2cacbe9b038160f27b9bd76aa58146eb 100644 --- a/tensorflow/compiler/xla/ptr_util.h +++ b/tensorflow/compiler/xla/ptr_util.h @@ -16,7 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_ -// Utility functions for pointers. +// As this was moved to tensorflow/core/util, provide indirections here to +// maintain current functionality of the library. #include @@ -24,55 +25,27 @@ limitations under the License. #include #include -namespace xla { - -namespace internal { - -// Trait to select overloads and return types for MakeUnique. -template -struct MakeUniqueResult { - using scalar = std::unique_ptr; -}; -template -struct MakeUniqueResult { - using array = std::unique_ptr; -}; -template -struct MakeUniqueResult { - using invalid = void; -}; +#include "tensorflow/core/util/ptr_util.h" -} // namespace internal +namespace xla { -// Transfers ownership of a raw pointer to a std::unique_ptr of deduced type. -// Example: -// X* NewX(int, int); -// auto x = WrapUnique(NewX(1, 2)); // 'x' is std::unique_ptr. -// -// WrapUnique is useful for capturing the output of a raw pointer factory. -// However, prefer 'MakeUnique(args...) over 'WrapUnique(new T(args...))'. -// auto x = WrapUnique(new X(1, 2)); // works, but nonideal. -// auto x = MakeUnique(1, 2); // safer, standard, avoids raw 'new'. -// -// Note: Cannot wrap pointers to array of unknown bound (i.e. U(*)[]). template std::unique_ptr WrapUnique(T* ptr) { - static_assert(!std::is_array::value || std::extent::value != 0, - "types T[0] or T[] are unsupported"); - return std::unique_ptr(ptr); + return tensorflow::WrapUnique(ptr); } template -typename internal::MakeUniqueResult::scalar MakeUnique(Args&&... args) { - return std::unique_ptr(new T(std::forward(args)...)); +typename tensorflow::helper::MakeUniqueResult::scalar MakeUnique( + Args&&... args) { + return tensorflow::MakeUnique(std::forward(args)...); } // Overload for array of unknown bound. // The allocation of arrays needs to use the array form of new, // and cannot take element constructor arguments. template -typename internal::MakeUniqueResult::array MakeUnique(size_t n) { - return std::unique_ptr(new typename std::remove_extent::type[n]()); +typename tensorflow::helper::MakeUniqueResult::array MakeUnique(size_t n) { + return tensorflow::MakeUnique(n); } } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 90aa9720a1e18bad06842adeead46fc3120d01dd..bdf92eaed1ff1d83cf03eec4d126677ea42c577f 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -102,7 +102,9 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( const Array3D& lhs, const Array3D& rhs, int64 kernel_stride, Padding padding, int64 lhs_dilation, int64 rhs_dilation, const ConvolutionDimensionNumbers& dnums) { - CHECK_EQ(dnums.spatial_dimensions_size(), 1); + CHECK_EQ(dnums.input_spatial_dimensions_size(), 1); + CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1); + CHECK_EQ(dnums.output_spatial_dimensions_size(), 1); // Reuse the code for Array4D-convolution by extending the 3D input into a 4D // array by adding a fourth dummy dimension of size 1 without stride, padding // and dilation. @@ -120,8 +122,9 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( }); // Add a second dummy spatial dimensions. ConvolutionDimensionNumbers dnums2d = dnums; - dnums2d.add_spatial_dimensions(3); + dnums2d.add_input_spatial_dimensions(3); dnums2d.add_kernel_spatial_dimensions(3); + dnums2d.add_output_spatial_dimensions(3); std::unique_ptr> convr4 = ConvArray4DGeneralDimensionsDilated( a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, {rhs_dilation, 1}, dnums2d); @@ -192,14 +195,26 @@ ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { std::vector dim_lengths{static_cast(operand.size())}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow1DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { + std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0]); @@ -465,9 +480,9 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( } ordered_input_dimensions[0] = - lhs_literal->shape().dimensions(dnums.spatial_dimensions(0)); + lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); ordered_input_dimensions[1] = - lhs_literal->shape().dimensions(dnums.spatial_dimensions(1)); + lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); ordered_kernel_dimensions[0] = rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); ordered_kernel_dimensions[1] = @@ -703,137 +718,4 @@ ReferenceUtil::ReduceToRowArray2D( return result; } -/* static */ std::unique_ptr> ReferenceUtil::PadArray2D( - const Array2D& operand, const PaddingConfig& padding, - const float pad) { - int64 in0 = operand.n1(); - int64 high_padding0 = padding.dimensions(0).edge_padding_high(); - int64 low_padding0 = padding.dimensions(0).edge_padding_low(); - int64 interior_padding0 = padding.dimensions(0).interior_padding(); - int64 out0 = - in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; - - int64 in1 = operand.n2(); - int64 high_padding1 = padding.dimensions(1).edge_padding_high(); - int64 low_padding1 = padding.dimensions(1).edge_padding_low(); - int64 interior_padding1 = padding.dimensions(1).interior_padding(); - int64 out1 = - in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; - - auto result = MakeUnique>(out0, out1); - result->Fill(pad); - int64 o0 = low_padding0; - for (int64 i0 = 0; i0 < in0; ++i0) { - int64 o1 = low_padding1; - for (int64 i1 = 0; i1 < in1; ++i1) { - if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { - (*result)(o0, o1) = operand(i0, i1); - } - o1 += interior_padding1 + 1; - } - o0 += interior_padding0 + 1; - } - return result; -} - -/* static */ Array3D ReferenceUtil::PadArray3D( - const Array3D& operand, const PaddingConfig& padding, - const float pad) { - CHECK_EQ(padding.dimensions_size(), 3); - - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3()}; - std::vector pad_low(3); - std::vector pad_high(3); - std::vector pad_interior(3); - std::vector output_bounds(3); - for (int64 i = 0; i < 3; ++i) { - pad_low[i] = padding.dimensions(i).edge_padding_low(); - pad_high[i] = padding.dimensions(i).edge_padding_high(); - CHECK_LE(0, pad_low[i]); - CHECK_LE(0, pad_high[i]); - CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; - pad_interior[i] = padding.dimensions(i).interior_padding(); - - output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + - (input_bounds[i] - 1) * pad_interior[i]; - } - - Array3D result(output_bounds[0], output_bounds[1], output_bounds[2]); - std::vector indices = {0, 0, 0}; - for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { - for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { - for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { - float* value = &result(indices[0], indices[1], indices[2]); - bool value_padded = false; - for (int i = 0; i < 3; ++i) { - bool in_low_padding = indices[i] < pad_low[i]; - bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; - if (in_low_padding || in_high_padding) { - *value = pad; - value_padded = true; - } - if (pad_interior[i] && - (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { - *value = pad; - value_padded = true; - } - } - if (value_padded) { - continue; - } - *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), - (indices[1] - pad_low[1]) / (pad_interior[1] + 1), - (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); - } - } - } - return result; -} - -/* static */ Array4D ReferenceUtil::PadArray4D( - const Array4D& operand, const PaddingConfig& padding, - const float pad) { - CHECK_EQ(padding.dimensions_size(), 4); - - const std::vector input_bounds = {operand.n1(), operand.n2(), - operand.n3(), operand.n4()}; - std::vector pad_low(4); - std::vector pad_high(4); - std::vector pad_interior(4); - std::vector output_bounds(4); - for (int64 i = 0; i < 4; ++i) { - pad_low[i] = padding.dimensions(i).edge_padding_low(); - pad_high[i] = padding.dimensions(i).edge_padding_high(); - CHECK_LE(0, padding.dimensions(i).interior_padding()) << "not implemented"; - pad_interior[i] = padding.dimensions(i).interior_padding(); - - output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + - (input_bounds[i] - 1) * pad_interior[i]; - } - - Array4D result(output_bounds[0], output_bounds[1], output_bounds[2], - output_bounds[3]); - result.Each([&](tensorflow::gtl::ArraySlice indices, float* value) { - for (int i = 0; i < 4; ++i) { - bool in_low_padding = indices[i] < pad_low[i]; - bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; - if (in_low_padding || in_high_padding) { - *value = pad; - return; - } - if (pad_interior[i] && - (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { - *value = pad; - return; - } - } - *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), - (indices[1] - pad_low[1]) / (pad_interior[1] + 1), - (indices[2] - pad_low[2]) / (pad_interior[2] + 1), - (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); - }); - return result; -} - } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 2da17307817858eea60e868f4be1ab8138784385..58e1a844610678f64677838e93f0379b63f65d39 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -70,7 +70,7 @@ class ReferenceUtil { // dilation factors. static std::unique_ptr> ConvArray4DGeneralDimensionsDilated( const Array4D& lhs, const Array4D& rhs, - std::pair stride, Padding padding, + std::pair kernel_stride, Padding padding, std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums); @@ -184,6 +184,12 @@ class ReferenceUtil { const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, @@ -486,19 +492,147 @@ class ReferenceUtil { } // Returns the result of a 2D pad on an input matrix. - static std::unique_ptr> PadArray2D( - const Array2D& operand, const PaddingConfig& padding, - const float pad); + template + static std::unique_ptr> PadArray2D( + const Array2D& operand, const PaddingConfig& padding, + const NativeT pad) { + int64 in0 = operand.n1(); + int64 high_padding0 = padding.dimensions(0).edge_padding_high(); + int64 low_padding0 = padding.dimensions(0).edge_padding_low(); + int64 interior_padding0 = padding.dimensions(0).interior_padding(); + int64 out0 = + in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0; + + int64 in1 = operand.n2(); + int64 high_padding1 = padding.dimensions(1).edge_padding_high(); + int64 low_padding1 = padding.dimensions(1).edge_padding_low(); + int64 interior_padding1 = padding.dimensions(1).interior_padding(); + int64 out1 = + in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1; + + auto result = MakeUnique>(out0, out1); + result->Fill(pad); + int64 o0 = low_padding0; + for (int64 i0 = 0; i0 < in0; ++i0) { + int64 o1 = low_padding1; + for (int64 i1 = 0; i1 < in1; ++i1) { + if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) { + (*result)(o0, o1) = operand(i0, i1); + } + o1 += interior_padding1 + 1; + } + o0 += interior_padding0 + 1; + } + return result; + } // Returns the result of a 3D pad on an input matrix. - static Array3D PadArray3D(const Array3D& operand, - const PaddingConfig& padding, - const float pad); + template + static Array3D PadArray3D(const Array3D& operand, + const PaddingConfig& padding, + const NativeT pad) { + CHECK_EQ(padding.dimensions_size(), 3); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3()}; + std::vector pad_low(3); + std::vector pad_high(3); + std::vector pad_interior(3); + std::vector output_bounds(3); + for (int64 i = 0; i < 3; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, pad_low[i]); + CHECK_LE(0, pad_high[i]); + CHECK_LE(0, padding.dimensions(i).interior_padding()) + << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array3D result(output_bounds[0], output_bounds[1], + output_bounds[2]); + std::vector indices = {0, 0, 0}; + for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) { + for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) { + for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) { + NativeT* value = &result(indices[0], indices[1], indices[2]); + bool value_padded = false; + for (int i = 0; i < 3; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + value_padded = true; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + value_padded = true; + } + } + if (value_padded) { + continue; + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1)); + } + } + } + return result; + } // Returns the result of a 4D pad on an input array. - static Array4D PadArray4D(const Array4D& operand, - const PaddingConfig& padding, - const float pad); + template + static Array4D PadArray4D(const Array4D& operand, + const PaddingConfig& padding, + const NativeT pad) { + CHECK_EQ(padding.dimensions_size(), 4); + + const std::vector input_bounds = {operand.n1(), operand.n2(), + operand.n3(), operand.n4()}; + std::vector pad_low(4); + std::vector pad_high(4); + std::vector pad_interior(4); + std::vector output_bounds(4); + for (int64 i = 0; i < 4; ++i) { + pad_low[i] = padding.dimensions(i).edge_padding_low(); + pad_high[i] = padding.dimensions(i).edge_padding_high(); + CHECK_LE(0, padding.dimensions(i).interior_padding()) + << "not implemented"; + pad_interior[i] = padding.dimensions(i).interior_padding(); + + output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] + + (input_bounds[i] - 1) * pad_interior[i]; + } + + Array4D result(output_bounds[0], output_bounds[1], + output_bounds[2], output_bounds[3]); + result.Each( + [&](tensorflow::gtl::ArraySlice indices, NativeT* value) { + for (int i = 0; i < 4; ++i) { + bool in_low_padding = indices[i] < pad_low[i]; + bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i]; + if (in_low_padding || in_high_padding) { + *value = pad; + return; + } + if (pad_interior[i] && + (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) { + *value = pad; + return; + } + } + *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1), + (indices[1] - pad_low[1]) / (pad_interior[1] + 1), + (indices[2] - pad_low[2]) / (pad_interior[2] + 1), + (indices[3] - pad_low[3]) / (pad_interior[3] + 1)); + }); + return result; + } // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, .... diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index eb6a71242ffa1499876b90f14f8a60ffdbdd069c..846ccdc83df900e3afedb6ababe07ebb1bd68f41 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -60,7 +60,9 @@ TEST_F(ReferenceUtilTest, TransposeArray2D) { TEST_F(ReferenceUtilTest, MatmulArray2D) { Array2D rhs({ - {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, + {7.f, 8.f}, + {9.f, 10.f}, + {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); auto actual_literal = Literal::CreateR2FromArray2D(*result); @@ -326,8 +328,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { dimension_numbers.set_input_feature_dimension(0); dimension_numbers.set_output_batch_dimension(2); dimension_numbers.set_output_feature_dimension(0); - dimension_numbers.add_spatial_dimensions(1); - dimension_numbers.add_spatial_dimensions(3); + dimension_numbers.add_input_spatial_dimensions(1); + dimension_numbers.add_output_spatial_dimensions(1); + dimension_numbers.add_input_spatial_dimensions(3); + dimension_numbers.add_output_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); dimension_numbers.set_kernel_input_feature_dimension(2); dimension_numbers.add_kernel_spatial_dimensions(1); @@ -380,8 +384,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { dimension_numbers.set_input_feature_dimension(0); dimension_numbers.set_output_batch_dimension(2); dimension_numbers.set_output_feature_dimension(0); - dimension_numbers.add_spatial_dimensions(1); - dimension_numbers.add_spatial_dimensions(3); + dimension_numbers.add_input_spatial_dimensions(1); + dimension_numbers.add_output_spatial_dimensions(1); + dimension_numbers.add_input_spatial_dimensions(3); + dimension_numbers.add_output_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); dimension_numbers.set_kernel_input_feature_dimension(2); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a15f3f654b14a715a2fbc71cdd38d46ac0268c02..c7432aacd18215d8c561b636a8ccc0da8118398c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -566,7 +566,6 @@ cc_library( hdrs = ["shaped_buffer.h"], deps = [ ":device_memory_allocator", - ":transfer_manager", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -630,6 +629,7 @@ cc_library( cc_library( name = "llvm_compiler", + srcs = ["llvm_compiler.cc"], hdrs = ["llvm_compiler.h"], deps = [ ":compiler", @@ -642,6 +642,7 @@ cc_library( srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], deps = [ + ":shaped_buffer", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -1053,9 +1054,7 @@ cc_library( srcs = ["algebraic_simplifier.cc"], hdrs = ["algebraic_simplifier.h"], deps = [ - ":call_inliner", ":hlo", - ":hlo_evaluator", ":hlo_pass", ":hlo_query", ":shape_inference", @@ -1091,6 +1090,32 @@ tf_cc_test( ], ) +cc_library( + name = "while_loop_simplifier", + srcs = ["while_loop_simplifier.cc"], + hdrs = ["while_loop_simplifier.h"], + deps = [ + ":call_inliner", + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_simplifier_test", + srcs = ["while_loop_simplifier_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_simplifier", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:test", + ], +) + cc_library( name = "defuser", srcs = ["defuser.cc"], @@ -1118,6 +1143,22 @@ tf_cc_test( ], ) +cc_library( + name = "dot_decomposer", + srcs = ["dot_decomposer.cc"], + hdrs = ["dot_decomposer.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -1267,24 +1308,6 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) -tf_cc_test( - name = "transfer_manager_test", - srcs = ["transfer_manager_test.cc"], - deps = [ - ":generic_transfer_manager", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service/cpu:cpu_transfer_manager", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - ], -) - cc_library( name = "hlo_cost_analysis", srcs = ["hlo_cost_analysis.cc"], @@ -1297,6 +1320,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], ) @@ -1334,6 +1358,7 @@ cc_library( deps = [ ":hlo", ":hlo_cost_analysis", + ":hlo_profile_printer", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -1342,6 +1367,18 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_execution_profile_test", + srcs = ["hlo_execution_profile_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo_cost_analysis", + ":hlo_execution_profile", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + tf_cc_test( name = "hlo_computation_test", srcs = ["hlo_computation_test.cc"], @@ -1618,10 +1655,14 @@ cc_library( deps = [ ":buffer_liveness", ":hlo", + ":hlo_alias_analysis", + ":hlo_dce", + ":hlo_graph_dumper", + ":hlo_ordering", ":hlo_pass", ":liveness_util", ":logical_buffer", - ":tuple_points_to_analysis", + ":tuple_simplifier", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1636,15 +1677,17 @@ tf_cc_test( deps = [ ":copy_insertion", ":hlo", + ":hlo_graph_dumper", ":hlo_matchers", - ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) @@ -1754,7 +1797,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], ) @@ -1825,7 +1867,6 @@ tf_cc_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], ) @@ -1864,6 +1905,22 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_element_type_converter", + srcs = ["hlo_element_type_converter.cc"], + hdrs = ["hlo_element_type_converter.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + cc_library( name = "device_memory_allocator", srcs = ["device_memory_allocator.cc"], @@ -1961,6 +2018,7 @@ cc_library( ":hlo", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -2126,6 +2184,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -2133,6 +2192,16 @@ cc_library( ], ) +cc_library( + name = "hlo_profile_printer", + srcs = ["hlo_profile_printer.cc"], + hdrs = ["hlo_profile_printer.h"], + deps = [ + ":human_readable_profile_builder", + "//tensorflow/compiler/xla:types", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index ee5cf8a10074d72d81374cf9dcb2cb2164f0d9db..2c0d1900eb6108eb8028fd89220758df03746647 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -24,10 +24,8 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/call_inliner.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" @@ -48,9 +46,6 @@ limitations under the License. namespace xla { namespace { -using tensorflow::gtl::nullopt; -using tensorflow::gtl::optional; - // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && @@ -137,7 +132,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleConvert(HloInstruction* convert) override; + Status HandleComplex(HloInstruction* complex) override; + Status HandleReal(HloInstruction* real) override; + Status HandleImag(HloInstruction* imag) override; Status HandleConvolution(HloInstruction* convolution) override; @@ -175,8 +173,6 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleMaximum(HloInstruction* maximum) override; Status HandleMinimum(HloInstruction* minimum) override; - Status HandleWhile(HloInstruction* while_op) override; - // Returns whether algebraic simplification has occurred. const bool changed() const { return changed_; } @@ -184,19 +180,46 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { static bool Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification); + bool enable_dot_strength_reduction, bool enable_conv_simplification); private: explicit AlgebraicSimplifierVisitor( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification) + bool enable_dot_strength_reduction, bool enable_conv_simplification) : computation_(computation), is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification), + enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} + // Transforms Dots where at least one input is a vector or has a degenerate + // dimension and converts it into a multiply and reduce. This should enable + // more fusion than leaving the nodes as Dot operations. + StatusOr HandleDotStrengthReduction(HloInstruction* dot); + + // Reshapes an instruction to rank 1 if it is not already rank 1. + HloInstruction* Flatten(HloInstruction* hlo) { + if (ShapeUtil::Rank(hlo->shape()) == 1) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(hlo->shape().element_type(), + {ShapeUtil::ElementsIn(hlo->shape())}), + hlo)); + } + + // Helper method to perform and add reduction in a single dimension. + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + HloInstruction* zero = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* AddReduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + return computation_->AddInstruction(HloInstruction::CreateReduce( + shape, hlo, zero, {dim}, AddReduce_computation)); + } + // Convenience method for replacing an instruction with a bitcast. void ReplaceWithBitcast(HloInstruction* instruction); @@ -269,8 +292,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Callback used to determine if a bitcast is possible. AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_; - // Disable dot simplication on platforms where it causes a slowdown. - bool enable_dot_simplification_; + // Disable dot strength reduction on platforms where it causes a slowdown. + bool enable_dot_strength_reduction_; // Disable convolution simplication on platforms where it causes a slowdown. bool enable_conv_simplification_; @@ -279,10 +302,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification, bool enable_conv_simplification) { + bool enable_dot_strength_reduction, bool enable_conv_simplification) { AlgebraicSimplifierVisitor visitor( computation, is_layout_sensitive, std::move(valid_bitcast_callback), - enable_dot_simplification, enable_conv_simplification); + enable_dot_strength_reduction, enable_conv_simplification); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -578,68 +601,72 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); - if (!enable_dot_simplification_) { - return Status::OK(); - } - // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or - // below. - if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || - ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } - - // Replace a zero element dot with a broadcast of the constant 0. - if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); - } - - // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). - if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { - auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, - rhs->mutable_operand(0), lhs->mutable_operand(0))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); - } +StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( + HloInstruction* dot) { + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int64 lhs_collapsing_dim = + dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + if (lhs->IsRank2Transpose()) { + lhs = lhs->mutable_operand(0); + lhs_collapsing_dim = 1 - lhs_collapsing_dim; + } + const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; + + int64 rhs_collapsing_dim = + dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + if (rhs->IsRank2Transpose()) { + rhs = rhs->mutable_operand(0); + rhs_collapsing_dim = 1 - rhs_collapsing_dim; + } + const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; + + auto reshape_if_necessary = [&](HloInstruction* hlo) { + if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + return hlo; + } + return computation_->AddInstruction( + HloInstruction::CreateReshape(dot->shape(), hlo)); + }; - // Simplify outer product into multiply with implicit broadcasting. - // - // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) { - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, - lhs, rhs)); - } + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, + int64 dim) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, {dim})); + }; - // The following graph transformations take Dots where at least one input is a - // vector or has a degenerate dimension and converts it into a multiply and - // reduce. This should enable more fusion than leaving the nodes as Dot - // operations. + auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { + return computation_->AddInstruction(HloInstruction::CreateBinary( + local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs)); + }; // Strength reduce dot(a[K] , b[K]) = // reshape(result.shape, // reduce_sum(multiply(a, b), {0})) if (ShapeUtil::Rank(rhs->shape()) == 1 && ShapeUtil::Rank(lhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, lhs, rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(AddReduce( + multiply(Flatten(lhs), Flatten(rhs)), 0)))); + return true; + } + + if (ShapeUtil::IsEffectiveScalar(rhs->shape()) && + ShapeUtil::IsEffectiveScalar(lhs->shape())) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs))))); + return true; + } + + // Simplify outer product into multiply with implicit broadcasting. + // + // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) + if (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), + broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); + return true; } // Strength reduce dot(a[1, K], b) = @@ -650,35 +677,21 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // ) // ) if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) { - auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(lhs->shape().element_type(), - {ShapeUtil::ElementsIn(lhs->shape())}), - lhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* reduce; + (ShapeUtil::Rank(lhs->shape()) == 2 && + lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - } else { - new_lhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {rhs->shape().dimensions(1)}), - multiply, zero, {0}, add_reduce_computation)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, + reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); + return true; } - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), + rhs_collapsing_dim), + rhs), + rhs_collapsing_dim)))); + return true; } // Strength reduce dot(a, b[K, 1]) = @@ -686,26 +699,60 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) { - auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(rhs->shape().element_type(), - {ShapeUtil::ElementsIn(rhs->shape())}), - rhs)); - new_rhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_kept_dim) == 1)) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(AddReduce( + multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), + lhs_collapsing_dim)), + lhs_collapsing_dim)))); + return true; + } + return false; +} + +Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { + auto lhs = dot->mutable_operand(0); + auto rhs = dot->mutable_operand(1); + + // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or + // below. + if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || + ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + return Status::OK(); + } + + // Replace a zero element dot with a broadcast of the constant 0. + if (ShapeUtil::HasZeroElements(dot->shape()) || + ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {lhs->shape().dimensions(0)}), - multiply, zero, {1}, add_reduce_computation)); return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); + } + + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + TF_ASSIGN_OR_RETURN(bool did_strength_reduction, + HandleDotStrengthReduction(dot)); + if (did_strength_reduction) { + return Status::OK(); + } } + + // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). + if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), + rhs->mutable_operand(0), lhs->mutable_operand(0), + dot_dimension_numbers)); + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); + } + return Status::OK(); } @@ -951,6 +998,18 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { return Status::OK(); } +// Complex(Real(c), Imag(c)) -> c +Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) { + auto real = complex->mutable_operand(0); + auto imag = complex->mutable_operand(1); + if (real->opcode() == HloOpcode::kReal && + imag->opcode() == HloOpcode::kImag && + real->operand(0) == imag->operand(0)) { + return ReplaceInstruction(complex, real->mutable_operand(0)); + } + return Status::OK(); +} + // Real(Complex(r, i)) -> r Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { auto operand = real->mutable_operand(0); @@ -1100,9 +1159,15 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { if (IsAll(rhs, -1)) { auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( Literal::One(rhs->shape().element_type()).CloneToUnique())); + + // Explicitly broadcast scalar 1 to the output shape, to avoid implicit + // broadcast in divide HLO as we are trying to eliminate implicit + // broadcasting at HLO level. + auto* broadcast_one = computation_->AddInstruction( + HloInstruction::CreateBroadcast(power->shape(), one, {})); return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, - one, lhs)); + broadcast_one, lhs)); } return Status::OK(); } @@ -1390,6 +1455,15 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( auto operand = reduce_window->mutable_operand(0); const Window& window = reduce_window->window(); auto function = reduce_window->to_apply(); + if (ShapeUtil::IsScalar(operand->shape())) { + TF_RET_CHECK(ShapeUtil::IsScalar(reduce_window->shape())); + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateMap(reduce_window->shape(), + {operand, reduce_window->mutable_operand(1)}, + function)); + } + VLOG(10) << "Considering folding Pad: " << operand->ToString() << "\ninto reduce-window: " << reduce_window->ToString(); @@ -1591,8 +1665,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); - auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } @@ -1673,312 +1750,6 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { return Status::OK(); } -// If all of instr's operands are either constants or have the form -// get-tuple-element(gte_operand, N) -// for the same value N, returns N. Otherwise, returns nullopt. -static optional GetGTEOperandIndex(const HloInstruction* instr, - const HloInstruction* gte_operand) { - VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " - << gte_operand->ToString() << ")"; - optional tuple_idx; - for (const HloInstruction* operand : instr->operands()) { - if (operand->IsConstant()) { - continue; - } - if (operand->opcode() != HloOpcode::kGetTupleElement) { - VLOG(2) << "instr uses something other than gte(gte_operand): " - << operand->ToString(); - return nullopt; - } - if (operand->operand(0) != gte_operand) { - VLOG(2) << "instr has gte whose operand is not gte_operand: " - << operand->ToString(); - return nullopt; - } - if (tuple_idx && tuple_idx != operand->tuple_index()) { - VLOG(2) << "instr has operands with conflicting gte indices, " - << *tuple_idx << " vs " << operand->tuple_index(); - return nullopt; - } - - tuple_idx = operand->tuple_index(); - } - return tuple_idx; -} - -// Tries to get the tuple index of the induction variable of a while loop. -// -// Checks that the loop condition and root both plumb the induction variable -// through the same tuple index, and that they both apply exactly one op to the -// induction variable before deciding whether to do another loop iteration (in -// the loop condition's case) or packing the induction variable into the result -// tuple (in the loop body's case). -// -// Specifically, checks that the loop condition has structure -// -// root = op(constants, get-tuple-elem(param0, N), constants) -// -// and the loop body has the structure -// -// inc = op(constants, get-tuple-elem(param0, N), constants) -// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). -// -// If so, returns N. Otherwise, returns nullopt. -static optional GetLoopInductionVarTupleIdx( - const HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Finding induction variable for loop " - << while_op->ToShortString(); - - // The while_cond computation should have the form - // - // while_cond_root = - // op(constants, get-tuple-elem(while_cond_param, N), constants). - // - // If it does, set indvar_tuple_idx to N. - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_param = while_cond->parameter_instruction(0); - optional indvar_tuple_idx = - GetGTEOperandIndex(while_cond_root, while_cond_param); - if (!indvar_tuple_idx) { - VLOG(2) << "Induction variable not found in loop condition: " - << while_cond->root_instruction()->ToString(); - return nullopt; - } - - // The while_body computation should have the form - // - // while_body_inc = - // op(constants, get-tuple-elem(while_body_param, N), constants) - // while_body_root = tuple(..., while_body_inc, ...) - // - // where while_body_inc is operand N of while_body_root. - auto* while_body = while_op->while_body(); - auto* while_body_root = while_body->root_instruction(); - if (while_body_root->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While body's root is not a tuple instruction: " - << while_body_root->ToString(); - return nullopt; - } - - auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); - auto* while_body_param = while_body->parameter_instruction(0); - optional while_body_indvar_tuple_idx = - GetGTEOperandIndex(while_body_inc, while_body_param); - if (!while_body_indvar_tuple_idx) { - VLOG(2) - << "Induction variable not found in while body increment instruction: " - << while_body_inc->ToString(); - return nullopt; - } - if (while_body_indvar_tuple_idx != indvar_tuple_idx) { - VLOG(2) << "Tuple index of induction variable does not match between loop " - "condition (" - << *indvar_tuple_idx << ") and while body (" - << *while_body_indvar_tuple_idx << ")"; - return nullopt; - } - - // Finally, check that the while loop's initial value is a tuple with enough - // elements. - auto* while_init = while_op->operand(0); - if (while_init->opcode() != HloOpcode::kTuple) { - VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); - return nullopt; - } - - VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; - return indvar_tuple_idx; -} - -// Finds and returns the non-constant operand in instr. -// -// CHECK-fails if instr doesn't have exactly one unique non-constant operand. -static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { - const HloInstruction* result = nullptr; - for (const HloInstruction* operand : instr->operands()) { - if (!operand->IsConstant()) { - if (result != nullptr) { - CHECK_EQ(result, operand); - } - result = operand; - } - } - CHECK_NE(result, nullptr); - return result; -} - -// Tries to determine the number of times the given loop executes. Currently -// simply returns 0, 1, or "can't tell" (nullopt). -static optional GetLoopTripCount(HloInstruction* while_op) { - CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); - VLOG(2) << "Getting trip count for loop " << while_op->ToString(); - - // The loop's induction variable is found at - // - // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), - // - // where comp is while_op->while_body() or while_op->while_condition(). - optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); - if (!indvar_tuple_idx) { - return nullopt; - } - - VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx - << " in input tuple."; - - // Now that we know the index of the induction variable, we can we can try to - // compute how many times the loop executes. Start by computing the induction - // variable's initial value. - HloEvaluator evaluator; - auto* while_init = while_op->mutable_operand(0); - auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); - StatusOr> indvar_init_result = - evaluator.Evaluate(indvar_init); - if (!indvar_init_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable init: " - << indvar_init_result.status(); - return nullopt; - } - - // Evaluates the while loop's condition, returning either "true" (continue - // looping), "false" (stop looping), or nullopt (can't evaluate). - auto evaluate_while_cond = [&](const Literal& indvar) -> optional { - auto* while_cond = while_op->while_condition(); - auto* while_cond_root = while_cond->root_instruction(); - auto* while_cond_indvar = NonConstantOperand(while_cond_root); - StatusOr> result = - evaluator.EvaluateWithSubstitutions(while_cond_root, - {{while_cond_indvar, &indvar}}); - if (!result.ok()) { - VLOG(2) << "Couldn't evaluate while cond: " << result.status(); - return nullopt; - } - return result.ValueOrDie()->GetArraySlice() == - tensorflow::gtl::ArraySlice{true}; - }; - - // The initial value of the induction variable. - const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); - - // Evaluate whether the while condition is true when seeded with - // indvar_iter0_val. - optional while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); - if (while_cond_iter0_val == false) { - VLOG(2) << "Loop has static trip count of 0."; - return 0; - } - - // Calculate the value of the induction variable after one iteration of the - // loop, and check whether the while condition is true with this new value. - auto* while_body = while_op->while_body(); - auto* while_body_indvar_update = - while_body->root_instruction()->operand(*indvar_tuple_idx); - auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); - StatusOr> indvar_iter1_result = - evaluator.EvaluateWithSubstitutions( - while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); - if (!indvar_iter1_result.ok()) { - VLOG(2) << "Couldn't evaluate induction variable update: " - << indvar_iter1_result.status(); - return nullopt; - } - const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); - optional while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); - if (while_cond_iter1_val == false) { - VLOG(2) << "Determined that loop has static trip count of 1."; - return 1; - } - - VLOG(2) << "Loop has unknown trip count >= 1."; - return nullopt; -} - -// Determines whether the given instruction is a send/recv node, or has a -// subcomputation which contains a send/recv node. -static bool IsOrContainsSendOrRecv(const HloInstruction* instr); - -// Determines whether the given computation contains a send or recv node. -static bool ContainsSendOrRecv(const HloComputation* comp) { - for (const auto* instr : comp->instructions()) { - if (IsOrContainsSendOrRecv(instr)) { - return true; - } - } - return false; -} - -static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { - if (instr->opcode() == HloOpcode::kSend || - instr->opcode() == HloOpcode::kRecv) { - return true; - } - for (const auto& subcomp : instr->called_computations()) { - if (ContainsSendOrRecv(subcomp)) { - return true; - } - } - return false; -} - -Status AlgebraicSimplifierVisitor::HandleWhile(HloInstruction* while_op) { - // We can't simplify while loops that contain send/recv nodes, because we rely - // on the particular loop structure around the node matching on the send and - // recv sides. - if (ContainsSendOrRecv(while_op->while_body()) || - ContainsSendOrRecv(while_op->while_condition())) { - VLOG(2) << "Not attempting to simplify while loop because it contains a " - "send/recv node: " - << while_op->ToShortString(); - return Status::OK(); - } - - // Cowardly refuse to simplify loops that are not removable. In practice, - // this means that we can't simplify loops that contain side-effecting - // instructions or have control predecessors/successors. - // - // This is not a fundamental limitation. The control operands can be moved - // onto the new HLOs after simplification, and any side-effecting ops inside - // the loop aren't removed, just cloned and added back to the loop. - // Nevertheless our infrastructure sees loop simplification as removal of - // these nodes and currently doesn't allow it. - if (!while_op->parent()->IsRemovable(while_op)) { - VLOG(2) << "Not attempting to simplify while loop it is not removable: " - << while_op->ToShortString(); - return Status::OK(); - } - - // Remove while loops with static trip count of 0. - optional trip_count = GetLoopTripCount(while_op); - if (trip_count && *trip_count == 0) { - // The loop never executes, so the value of the loop is the value of its - // "init" operand. - auto computation = while_op->parent(); - - // Remove while_op (i.e., call ReplaceInstruction rather than - // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in - // a loop without an intervening DCE, we don't try to re-simplify the loop. - TF_RETURN_IF_ERROR(computation->ReplaceInstruction( - while_op, while_op->mutable_operand(0))); - changed_ = true; - return Status::OK(); - } - - // Transform while loops with static trip count of 1 into a call op, then - // inline the call. - if (trip_count && *trip_count == 1) { - auto computation = while_op->parent(); - auto call_op = computation->AddInstruction(HloInstruction::CreateCall( - while_op->shape(), while_op->operands(), while_op->while_body())); - TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_op)); - changed_ = true; - return Status::OK(); - } - return Status::OK(); -} - StatusOr AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); @@ -1986,7 +1757,7 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { for (auto* comp : module->MakeNonfusionComputations()) { if (AlgebraicSimplifierVisitor::Run( comp, is_layout_sensitive_, valid_bitcast_callback_, - enable_dot_simplification_, enable_conv_simplification_)) { + enable_dot_strength_reduction_, enable_conv_simplification_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index a9f476178c7af74c275a10de7727ea64e17d590f..43315f5cdc7afbe79039420320f4a0d0535e11f1 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -40,11 +40,11 @@ class AlgebraicSimplifier : public HloPassInterface { // bitcasts. AlgebraicSimplifier(bool is_layout_sensitive, ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification = true, + bool enable_dot_strength_reduction = true, bool enable_conv_simplification = true) : is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification), + enable_dot_strength_reduction_(enable_dot_strength_reduction), enable_conv_simplification_(enable_conv_simplification) {} ~AlgebraicSimplifier() override = default; tensorflow::StringPiece name() const override { return "algsimp"; } @@ -58,7 +58,7 @@ class AlgebraicSimplifier : public HloPassInterface { ValidBitcastCallback valid_bitcast_callback_; // Enable dot simplication on platforms where it is profitable. - bool enable_dot_simplification_; + bool enable_dot_strength_reduction_; // Enable convolution simplication on platforms where it is profitable. bool enable_conv_simplification_; diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 87d4fc9663daf3cc2806dfa6550812dd9b08b36c..7462e397ff07779c04bce18b68419bff9686dbd5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -47,69 +47,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -class AlgebraicSimplifierTest : public HloVerifiedTestBase { - public: - // Makes a computation that contains a loop that runs num_iters times. - HloComputation* MakeSimpleLoop(HloModule* module, int num_iters); -}; - -HloComputation* AlgebraicSimplifierTest::MakeSimpleLoop(HloModule* module, - int num_iters) { - HloComputation::Builder builder(TestName()); - - auto loop_iter_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42))); - auto loop_data_init = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR1({0, 1, 2}))); - auto loop_init = builder.AddInstruction( - HloInstruction::CreateTuple({loop_iter_init, loop_data_init})); - - HloComputation* condition; - { - HloComputation::Builder cond_builder(TestName() + ".condition"); - auto loop_var = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto loop_induction_var = - cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShape(S32, {}), loop_var, 0)); - auto limit = cond_builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(42 + num_iters))); - cond_builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, loop_induction_var, - limit)); - condition = module->AddEmbeddedComputation(cond_builder.Build()); - } - - HloComputation* body; - { - HloComputation::Builder body_builder(TestName() + ".body"); - auto loop_var = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); - auto loop_induction_var = - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - ShapeUtil::MakeShape(S32, {}), loop_var, 0)); - auto new_loop_induction_var = - body_builder.AddInstruction(HloInstruction::CreateBinary( - loop_induction_var->shape(), HloOpcode::kAdd, loop_induction_var, - body_builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(1))))); - auto loop_data = - body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( - loop_data_init->shape(), loop_var, 1)); - auto new_loop_data = - body_builder.AddInstruction(HloInstruction::CreateBinary( - loop_data_init->shape(), HloOpcode::kMultiply, loop_data, - loop_data)); - body_builder.AddInstruction( - HloInstruction::CreateTuple({new_loop_induction_var, new_loop_data})); - body = module->AddEmbeddedComputation(body_builder.Build()); - } - - builder.AddInstruction(HloInstruction::CreateWhile( - loop_init->shape(), condition, body, loop_init)); - - return module->AddEntryComputation(builder.Build()); -} +class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -433,6 +371,31 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { EXPECT_EQ(root, param0); } +// Test that complex(real(c), imag(c)) is simplified to c. +TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + Shape r2c64 = ShapeUtil::MakeShape(C64, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2c64, "param0")); + HloInstruction* real = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, param0)); + HloInstruction* imag = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, param0)); + HloInstruction* cplx = builder.AddInstruction( + HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, cplx); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + // Test that real(complex(r,i)) is simplified to r. TEST_F(AlgebraicSimplifierTest, RealOfComplex) { Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); @@ -798,8 +761,10 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); - EXPECT_THAT(root, op::Divide(op::Constant(), param0)); - EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); + EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); + EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kBroadcast); + EXPECT_EQ(root->operand(0)->operand(0)->literal().GetFirstElement(), + 1); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -1659,8 +1624,11 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { ConvolutionDimensionNumbers dnums; std::vector in_dims; int in_channel_idx = -1; - dnums.add_spatial_dimensions(-1); // filled in later - dnums.add_spatial_dimensions(-1); // filled in later + // filled in later + dnums.add_input_spatial_dimensions(-1); + dnums.add_output_spatial_dimensions(-1); + dnums.add_input_spatial_dimensions(-1); + dnums.add_output_spatial_dimensions(-1); for (int i = 0; i < strlen(options.dim_order); ++i) { char ch = options.dim_order[i]; if (ch == 'N') { @@ -1668,10 +1636,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { dnums.set_output_batch_dimension(i); in_dims.push_back(options.in_batch); } else if (ch == 'H') { - dnums.set_spatial_dimensions(0, i); + dnums.set_input_spatial_dimensions(0, i); + dnums.set_output_spatial_dimensions(0, i); in_dims.push_back(options.in_height); } else if (ch == 'W') { - dnums.set_spatial_dimensions(1, i); + dnums.set_input_spatial_dimensions(1, i); + dnums.set_output_spatial_dimensions(1, i); in_dims.push_back(options.in_width); } else if (ch == 'C') { dnums.set_input_feature_dimension(i); @@ -2168,8 +2138,10 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -2208,99 +2180,6 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { op::Tuple(op::Constant(), op::Constant())); } -TEST_F(AlgebraicSimplifierTest, WhileLoopWithZeroIterations) { - HloModule module(TestName()); - HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/0); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Tuple(op::Constant(), op::Constant())); -} - -TEST_F(AlgebraicSimplifierTest, WhileLoopWithOneIteration) { - HloModule module(TestName()); - HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - EXPECT_THAT(computation->root_instruction(), - op::Tuple(op::Add(), op::Multiply())); -} - -TEST_F(AlgebraicSimplifierTest, WhileLoopWithTwoIterations) { - HloModule module(TestName()); - MakeSimpleLoop(&module, /*num_iters=*/2); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); -} - -TEST_F(AlgebraicSimplifierTest, WhileLoopWithControlDependency) { - HloModule module(TestName()); - HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1); - auto* while_op = computation->root_instruction(); - ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); - auto* true_op = while_op->while_body()->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))); - TF_ASSERT_OK(true_op->AddControlDependencyTo( - while_op->while_body()->root_instruction())); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); - EXPECT_THAT(computation->root_instruction()->control_predecessors(), - ElementsAre(op::Constant())) - << computation->ToString(); -} - -// Loops that contain send/recv nodes can't be simplified; the loop structure -// around send/recv nodes must be preserved. -TEST_F(AlgebraicSimplifierTest, NotRemovedIfContainsSend) { - HloModule module(TestName()); - HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1); - auto* while_op = computation->root_instruction(); - ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); - auto* while_body = while_op->while_body(); - while_body->AddInstruction(HloInstruction::CreateSend( - while_body->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(true))), - /*channel_id=*/0)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); -} - -TEST_F(AlgebraicSimplifierTest, NotRemovedIfContainsRecv) { - HloModule module(TestName()); - HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1); - auto* while_op = computation->root_instruction(); - ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); - auto* while_body = while_op->while_body(); - while_body->AddInstruction( - HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), - /*channel_id=*/0)); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); -} - -// The limitation on not being able to simplify loops that contain infeeds (and -// other non-removable instructions) isn't fundamental -- it just stems from the -// fact that our infrastructure sees simplifying such a loop as tantamount to -// removing the non-removable instruction. -TEST_F(AlgebraicSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { - HloModule module(TestName()); - HloComputation* computation = MakeSimpleLoop(&module, /*num_iters=*/1); - auto* while_op = computation->root_instruction(); - ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); - auto* while_body = while_op->while_body(); - while_body->AddInstruction( - HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); - AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, - non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); -} - // A dynamic-slice is trivial if its start indices are all zeroes and the size // of its input equals the size of its output. In this case, the dynamic slice // is equal to its input. @@ -2359,5 +2238,63 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { op::DynamicSlice(op::Parameter(), op::Parameter())); } +class DotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + int m, k, n; + bool transpose_lhs, transpose_rhs; + std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( + 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs")); + if (transpose_lhs) { + lhs = builder.AddInstruction( + HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0})); + } + auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( + 1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs")); + if (transpose_rhs) { + rhs = builder.AddInstruction( + HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0})); + } + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(module.get())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = + dot_should_be_transformed || (transpose_lhs && transpose_rhs); + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_CASE_P( + DotStrengthReductionTestInstantiation, DotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Bool(), + ::testing::Bool())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 9abe30e3f371cc294c36c1dcd743224b11b0c4f5..05f2d062784147108a94ffb7bb0ca42ddfe4f010 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS + #include "tensorflow/compiler/xla/service/backend.h" #include #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc index abe881cd1a58a6173b9b93f10a7308d70106c889..2bbae25aee3db95406fd247deb788d2976207ba3 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc @@ -85,9 +85,9 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { HloOpcode opcode) { HloComputation::Builder b("scalar_computation"); auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); + 0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs")); auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); + 1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs")); auto scalar_op = b.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), opcode, scalar_lhs, scalar_rhs)); @@ -149,26 +149,41 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( if (!rewrite_training_op_) { return Status::OK(); } + + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + // Expand batch norm training into smaller HLO ops. HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); + PrimitiveType ptype = operand_shape.element_type(); int64 feature_index = batch_norm->feature_index(); const int64 feature_count = operand_shape.dimensions(feature_index); const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape); - auto elements_per_feature = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(size_in_elements / feature_count))); + auto elements_per_feature_literal = + Literal::CreateR0(size_in_elements / feature_count); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); const Shape feature_shape = scale->shape(); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + auto zero_literal = Literal::CreateR0(0.0f); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -177,103 +192,110 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( } } - auto scale_broadcasted = computation_->AddInstruction( + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(F32, HloOpcode::kAdd); + GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // X^2. - auto operand_squared = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, operand, operand)); // Sum[X]. - auto sum = computation_->AddInstruction(HloInstruction::CreateReduce( - feature_shape, operand, zero, dimensions_without_feature, - add_reduce_computation)); + auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, + dimensions_without_feature, + add_reduce_computation)); // Sum[X^2]. - auto squared_sum = computation_->AddInstruction(HloInstruction::CreateReduce( + auto squared_sum = add(HloInstruction::CreateReduce( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); // Fuse two parallel reduces together to improve performance. - if (use_fusion_) { - auto tuple = computation_->AddInstruction( - HloInstruction::CreateTuple({sum, squared_sum})); + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); auto fused = computation_->CreateFusionInstruction( {tuple, sum, squared_sum, operand_squared}, HloInstruction::FusionKind::kInput); - sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - squared_sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + squared_sum = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // E[X]. - auto mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. - auto square_mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto square_mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); // E^2[X]. - auto mean_square = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean_square = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, mean, mean)); // Var[X]. - auto var = computation_->AddInstruction(HloInstruction::CreateBinary( + auto var = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = computation_->AddInstruction( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kAdd, - scaled_normalized, offset_broadcasted)); - - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({shifted_normalized, mean, var}))); + auto shifted_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + + auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); + + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } @@ -286,6 +308,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); int64 feature_index = batch_norm->feature_index(); + PrimitiveType ptype = operand_shape.element_type(); HloInstruction* scale = batch_norm->mutable_operand(1); HloInstruction* offset = batch_norm->mutable_operand(2); @@ -293,8 +316,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( HloInstruction* var = batch_norm->mutable_operand(4); const Shape feature_shape = scale->shape(); + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; @@ -304,50 +329,69 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( } } - auto scale_broadcasted = computation_->AddInstruction( + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. auto shifted_normalized = HloInstruction::CreateBinary( operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted); + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + if (batch_norm->has_sharding()) { + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + shifted_normalized->set_sharding(batch_norm->sharding()); + } TF_CHECK_OK( ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); return Status::OK(); @@ -370,9 +414,17 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( if (!rewrite_grad_op_) { return Status::OK(); } + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); const Shape activation_shape = activation->shape(); + PrimitiveType ptype = activation_shape.element_type(); HloInstruction* scale = batch_norm->mutable_operand(1); const Shape feature_shape = scale->shape(); HloInstruction* mean = batch_norm->mutable_operand(2); @@ -383,18 +435,26 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape); const int64 feature_count = activation_shape.dimensions(feature_index); - auto elements_per_feature = - computation_->AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(size_in_elements / feature_count))); - - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(-0.5f))); - - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + auto elements_per_feature_literal = + Literal::CreateR0(size_in_elements / feature_count); + TF_ASSIGN_OR_RETURN(elements_per_feature_literal, + elements_per_feature_literal->Convert(ptype)); + auto elements_per_feature = add( + HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); + + auto zero_literal = Literal::CreateR0(0.0f); + TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); + + auto neg_half_literal = Literal::CreateR0(-0.5f); + TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); + + auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); + TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; @@ -404,126 +464,131 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( } } - auto scale_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, scale, {feature_index})); - auto variance_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, variance, {feature_index})); + auto scale_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, scale, {feature_index})); + auto variance_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, variance, {feature_index})); // E[X]. - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kAdd, variance, epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)), + neg_half)); + + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, + epsilon)), + neg_half)); // X - E[X]. - auto activation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - activation, mean_broadcasted)); + auto activation_minus_mean = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); // Grad[Y] * (X - E[X]). - auto grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + auto grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, activation_minus_mean)); HloComputation* add_reduce_computation = - GetScalarBinaryComputation(F32, HloOpcode::kAdd); + GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // sum(Grad[Y] * (X - E[X])). auto sum_grad_output_times_activiation_minus_mean = - computation_->AddInstruction(HloInstruction::CreateReduce( + add(HloInstruction::CreateReduce( feature_shape, grad_output_times_activiation_minus_mean, zero, dimensions_without_feature, add_reduce_computation)); // Grad[beta] = Sum(Grad[Y]). - auto grad_beta = computation_->AddInstruction(HloInstruction::CreateReduce( + auto grad_beta = add(HloInstruction::CreateReduce( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_) { - auto tuple = computation_->AddInstruction(HloInstruction::CreateTuple( + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple( {sum_grad_output_times_activiation_minus_mean, grad_beta})); auto fused = computation_->CreateFusionInstruction( {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, HloInstruction::FusionKind::kInput); - sum_grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum_grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - grad_beta = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + grad_beta = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). - auto grad_scale = computation_->AddInstruction(HloInstruction::CreateBinary( + auto grad_scale = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); // I2 = Sum(Grad[Y]) - auto I2 = computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, grad_beta, {feature_index})); + auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, + {feature_index})); // I3 = Sum(Grad[Y] * (X - E[X])) - auto I3 = computation_->AddInstruction(HloInstruction::CreateBroadcast( + auto i3 = add(HloInstruction::CreateBroadcast( activation_shape, sum_grad_output_times_activiation_minus_mean, {feature_index})); // I4 = (X - E[X]) * I3 - auto I4 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, I3, activation_minus_mean)); + auto i4 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); // I5 = I4 / (Var[X] + epsilon) - auto I5 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, I4, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kAdd, variance_broadcasted, epsilon)))); + auto i5 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, i4, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)))); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted)); - scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, - scale_times_rsqrt_var_add_epsilon, elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, + elements_per_feature)); - auto I1 = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto i1 = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, elements_per_feature)); // I6 = I1 - I2 - I5 - auto I6 = computation_->AddInstruction(HloInstruction::CreateBinary( + auto i6 = add(HloInstruction::CreateBinary( activation_shape, HloOpcode::kSubtract, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, I1, I2)), - I5)); + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, + i1, i2)), + i5)); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. - auto grad_activation = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, I6)); + auto grad_activation = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6)); + auto tuple = + HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), activation_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}))); + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 8536429846f87fd5c4b073cc4b13b3f1c5eb2e5c..7ece79d781acfaffc21d6a29e8a12e68622a1617 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -101,6 +101,11 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); } + std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), + [](const BufferAllocationProto::Assigned& assign1, + const BufferAllocationProto::Assigned& assign2) { + return assign1.logical_buffer_id() < assign2.logical_buffer_id(); + }); return proto; } @@ -260,6 +265,42 @@ bool BufferAssignment::SharesSliceAtIndex( GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie(); } +bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const { + using SliceSet = + FlatSet; + // Gets the slices all of instr's subshapes. If any subshape doesn't have an + // assigned slice, returns the empty set. + auto collect_slices = [&](const HloInstruction* instr) -> SliceSet { + SliceSet slices; + Status status = ShapeUtil::ForEachSubshapeWithStatus( + instr->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) { + auto shape_slices = GetAllSlices(instr, index); + if (shape_slices.empty()) { + return InvalidArgument("No slices assigned to part of instr."); + } + slices.insert(shape_slices.begin(), shape_slices.end()); + return Status::OK(); + }); + if (!status.ok()) { + return {}; + } + return slices; + }; + + SliceSet slices_a = collect_slices(hlo_a); + SliceSet slices_b = collect_slices(hlo_b); + // hlo_a and hlo_b have disjoint slices if collect_slices succeeded (i.e. + // didn't return the empty set) for both HLOs, and the two resulting sets of + // slices are disjoint. + return !slices_a.empty() && !slices_b.empty() && + std::none_of(slices_a.begin(), slices_a.end(), + [&](const BufferAllocation::Slice& slice) { + return slices_b.count(slice) > 0; + }); +} + StatusOr BufferAssignment::GetUniqueTopLevelOutputSlice() const { return GetUniqueTopLevelSlice( @@ -492,19 +533,19 @@ Status GatherComputationsByAllocationType( std::vector* global_computations) { // Create a worklist of computations paired with whether the allocation must // be thread-local. - std::deque> worklist; + std::deque> worklist; worklist.push_back(std::make_pair(module->entry_computation(), /*is_thread_local*/ false)); // Sets for quickly checking membership. Computations are returned in vectors // for stable iteration. - FlatSet thread_local_set; - FlatSet global_set; + FlatSet thread_local_set; + FlatSet global_set; while (!worklist.empty()) { auto worklist_front = worklist.front(); worklist.pop_front(); - HloComputation* computation = worklist_front.first; + const HloComputation* computation = worklist_front.first; bool is_thread_local = worklist_front.second; bool in_thread_local_set = thread_local_set.count(computation) > 0; bool in_global_set = global_set.count(computation) > 0; @@ -540,6 +581,7 @@ Status GatherComputationsByAllocationType( instruction->called_computations()) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: // Call and while must be called from a computation with global // allocations as they may return references to buffers inside the @@ -648,7 +690,7 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, } if (allow_input_output_aliasing_ && allocation->maybe_live_out()) { - HloComputation* entry_computation = + const HloComputation* entry_computation = assignment->module_->entry_computation(); for (auto param : entry_computation->parameter_instructions()) { for (auto& param_buffer : @@ -814,17 +856,6 @@ Status BufferAssigner::AssignBuffersForComputation( continue; } - if (instruction->opcode() == HloOpcode::kRecv) { - // Make sure that recv operations get a new unique allocation so that - // don't share their buffer with any other operations. - BufferAllocation* allocation = assignment->NewAllocation( - *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); - allocation_indices.push_back(allocation->index()); - VLOG(3) << "New allocation #" << allocation->index() - << " for recv: " << *buffer; - continue; - } - if (ShapeUtil::IsTuple(buffer->shape())) { // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend // assumes longer buffer liveness than indicated by the analysis. @@ -946,8 +977,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, - // since buffers for kCall and kWhile sub-computations are only live for the - // duration of their calling instructions. + // since buffers for kCall, kWhile, and kConditional sub-computations are + // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; SequentialHloOrdering::HloModuleSequence module_sequence; FlatSet all_buffers_to_assign; @@ -1235,7 +1266,6 @@ const LogicalBuffer* AddBufferToColocatedSet( // CopyInsertion ensures root points-to set is unambiguous and distinct. const auto& points_to = points_to_analysis.GetPointsToSet(instruction); DCHECK(!points_to.IsAmbiguous()); - DCHECK(points_to.IsDistinct()); colocated_set->push_back(points_to.element(index)[0]); return colocated_set->back(); } @@ -1243,7 +1273,8 @@ const LogicalBuffer* AddBufferToColocatedSet( } // namespace // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated -// in the same allocation (currently just supports kWhile and kCall). +// in the same allocation (currently just supports kWhile, kCall, and +// kConditional). void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, @@ -1307,6 +1338,26 @@ void BufferAssigner::BuildColocatedBufferSets( &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); + } else if (opcode == HloOpcode::kConditional) { + const HloInstruction* conditional_hlo = instruction; + ShapeUtil::ForEachSubshape( + conditional_hlo->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector colocated_set; + // Add conditional.result. + AddBufferToColocatedSet(conditional_hlo, index, + points_to_analysis, &colocated_set); + // Add conditional.true_computation.root. + AddBufferToColocatedSet( + conditional_hlo->true_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + // Add conditional.false_computation.root. + AddBufferToColocatedSet( + conditional_hlo->false_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); } } } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 08a53af8baa3f250919517c87c023c329b129024..08a40bfeb2a2a78c25805308e73154c6cc667f21 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -327,6 +327,12 @@ class BufferAssignment { return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); } + // Returns true if hlo_a and hlo_b both have at least one buffer assigned for + // their top-level and each of their nested shape indices, and if hlo_a's + // buffers are all different from hlo_b's buffers. + bool HaveDisjointSlices(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const; + // Returns the underlying points-to analysis used for this assignment. const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 89410f42bd7b5fa8f9b380c868fcd4fedb54576c..6fc9d783f1b34de8c0f93c6aa342591891d08eaf 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -85,7 +85,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunBufferAssignment(HloModule* module, int64 alignment = 1) { return BufferAssigner::Run( - module, MakeUnique(module), + module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); @@ -94,7 +94,7 @@ class BufferAssignmentTest : public HloTestBase { std::unique_ptr RunColoredBufferAssignment( HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) { return BufferAssigner::Run( - module, MakeUnique(module), + module, xla::MakeUnique(module), backend().compiler()->BufferSizeBytesFunction(), [alignment](LogicalBuffer::Color) { return alignment; }, false, std::move(colorer)) @@ -166,6 +166,15 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr BuildR0F32UnaryOpComputation( + HloOpcode opcode, const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param)); + return builder.Build(); + } + // Verifies that the given instruction hlo has a valid input buffer assigned, // i.e., the parameter number matches the op's. const BufferAllocation& GetAssignedInputAllocation( @@ -740,6 +749,56 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { << " instructions; total buffer size " << size0 + sizec + sizeb; } +TEST_F(BufferAssignmentTest, ExampleConditional) { + auto module = CreateNewModule(); + auto true_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); + auto false_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor")); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + auto const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.4f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + r0f32_, pred, const1, true_computation, const2, false_computation)); + module->AddEntryComputation(builder.Build()); + + const std::vector conditional_instrs = + GetInstructions(conditional); + const std::vector true_instrs = + GetInstructions(true_computation->root_instruction()); + const std::vector false_instrs = + GetInstructions(false_computation->root_instruction()); + EXPECT_EQ(4, conditional_instrs.size()); + EXPECT_EQ(2, true_instrs.size()); + EXPECT_EQ(2, false_instrs.size()); + + auto buffers = RunBufferAssignment(module.get()); + ValidateBuffers(conditional_instrs, *buffers); + ValidateBuffers(true_instrs, *buffers); + ValidateBuffers(false_instrs, *buffers); + + EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers)) + << "Should be reuse between conditional and true computation."; + EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers)) + << "Should be reuse between conditional and false computation."; + EXPECT_FALSE(BuffersDistinct(true_instrs, false_instrs, *buffers)) + << "Should be reuse between true and false computations."; + + const BufferAllocation& conditional_buffer = + GetTopLevelAllocation(*buffers, conditional); + const BufferAllocation& true_buffer = + GetTopLevelAllocation(*buffers, true_computation->root_instruction()); + const BufferAllocation& false_buffer = + GetTopLevelAllocation(*buffers, false_computation->root_instruction()); + EXPECT_EQ(conditional_buffer.size(), true_buffer.size()); + EXPECT_EQ(conditional_buffer.size(), false_buffer.size()); +} + TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); @@ -1360,10 +1419,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateParameter(1, shape_3x4, "param_b")); auto param_c = builder.AddInstruction( HloInstruction::CreateParameter(2, shape_4x4, "param_c")); - auto dot_ab = builder.AddInstruction(HloInstruction::CreateBinary( - shape_2x4, HloOpcode::kDot, param_a, param_b)); - auto dot_bc = builder.AddInstruction(HloInstruction::CreateBinary( - shape_3x4, HloOpcode::kDot, param_b, param_c)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_ab = builder.AddInstruction( + HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); + auto dot_bc = builder.AddInstruction( + HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); @@ -1448,7 +1510,7 @@ class WhileBufferAssignmentTest : public HloTestBase { auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); return BufferAssigner::Run( - module, MakeUnique(module, sequence), + module, xla::MakeUnique(module, sequence), ByteSizeOf, [alignment](LogicalBuffer::Color) { return alignment; }) .ConsumeValueOrDie(); @@ -1469,7 +1531,7 @@ static void RunCopyInsertion(HloModule* module) { } TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1526,7 +1588,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { } TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( @@ -1538,8 +1600,6 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); - auto output1 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -1556,10 +1616,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto body1 = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); - auto tuple1 = builder.AddInstruction( - HloInstruction::CreateTuple({input0, weights0, output1})); auto while1 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); module->AddEntryComputation(builder.Build()); RunCopyInsertion(module.get()); @@ -1575,7 +1633,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { } TEST_F(BufferAssignmentTest, TwoCalls) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {}); HloComputation* sub_computation; { @@ -1640,7 +1698,7 @@ static bool IsPostOrderTraversal( } TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder(TestName()); auto zero = builder.AddInstruction( @@ -1676,11 +1734,14 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); + auto gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); + auto gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( - while0->shape(), HloOpcode::kAdd, while0, while1)); - module->AddEntryComputation(builder.Build()); + while0->shape(), HloOpcode::kAdd, gte0, gte1)); - RunCopyInsertion(module.get()); + module->AddEntryComputation(builder.Build()); { FlattenCallGraph flatten; @@ -1688,84 +1749,35 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { EXPECT_TRUE(result); } + RunCopyInsertion(module.get()); + auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually // crafted sequence. - std::vector sequence_for_buffer_assigment = { - input1, weights1, one, output1, tuple1, while1, input0, - weights0, zero, output0, tuple0, while0, root_add}; + sequence[module->entry_computation()] = { + input1, weights1, one, output1, while1->operand(0), while1, + input0, weights0, zero, output0, while0->operand(0), while0, + gte0, gte1, root_add}; // If this ASSERT_TRUE fails, we constructed a bogus sequence above // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment)); - - sequence[module->entry_computation()] = - std::move(sequence_for_buffer_assigment); + ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); auto assignment = BufferAssigner::Run( module.get(), - MakeUnique(module.get(), sequence), ByteSizeOf, - [](LogicalBuffer::Color) { return 1; }) + xla::MakeUnique(module.get(), sequence), + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } -// Test buffer assignment for while nodes with multiple uses. -// TODO(b/37245345): Fix buffer assignment for this case. -TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { - auto module = MakeUnique(TestName()); - auto builder = HloComputation::Builder(TestName()); - - auto input0 = builder.AddInstruction( - HloInstruction::CreateParameter(0, data_shape_, "input0")); - auto weights0 = builder.AddInstruction( - HloInstruction::CreateParameter(1, data_shape_, "weights0")); - - auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0))); - auto output0 = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, zero, {1})); - - auto cond0 = - module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); - auto body0 = - module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); - - auto tuple0 = builder.AddInstruction( - HloInstruction::CreateTuple({input0, weights0, output0})); - auto while0 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); - auto while1 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0)); - - auto get0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); - auto get1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); - builder.AddInstruction( - HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1)); - module->AddEntryComputation(builder.Build()); - - RunCopyInsertion(module.get()); - - { - FlattenCallGraph flatten; - TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); - EXPECT_TRUE(result); - } - - auto assignment = RunBufferAssignment(module.get()); - - EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); -} - TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { - auto module = MakeUnique(TestName()); + auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); auto input0 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index 56600b583803e23324db778959de620440fce5cf..13825fe05bb1b98045f1a3dac3d7272a2d1151fb 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -120,7 +120,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate)); @@ -167,10 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run( - module.get(), - MakeUnique(module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = + BufferLiveness::Run(module.get(), xla::MakeUnique( + module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -216,7 +216,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp)); @@ -250,7 +250,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -294,7 +294,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { std::vector order = {param, negate, exp, add}; module_sequence.emplace(computation, order); auto liveness = - BufferLiveness::Run(module.get(), MakeUnique( + BufferLiveness::Run(module.get(), xla::MakeUnique( module.get(), module_sequence)) .ConsumeValueOrDie(); @@ -334,7 +334,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // All buffers should be live out except the param @@ -370,7 +370,7 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Buffers in different computations should always interfere. @@ -409,7 +409,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Only the element buffers of the tuple constant which are pointed to by @@ -474,7 +474,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -536,7 +536,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { auto liveness = BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // We compare tuple element pairs that are input/output to the computation: @@ -624,8 +624,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. @@ -736,8 +736,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - MakeUnique(module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 1adecdb939cb2c1259003d3be2c90b5a299b0f30..13eb02ca012f44b2b5ed7c6f5becb7d54b07c33c 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -54,6 +54,7 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { CallContext GetInstructionCallContext(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 0395ea8c8b52315f7ca2221f412750ebadda2dd8..1ea7d538cd515c3098b6a1f03c6146d288330406 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -34,12 +34,13 @@ using ::testing::UnorderedElementsAre; class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. - std::unique_ptr MakeScalarComputation() { + std::unique_ptr MakeScalarComputation( + HloOpcode opcode = HloOpcode::kNegate) { HloComputation::Builder builder(TestName() + ".ScalarComputation"); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); builder.AddInstruction( - HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0)); + HloInstruction::CreateUnary(kScalarShape, opcode, param0)); return builder.Build(); } @@ -236,6 +237,54 @@ TEST_F(CallGraphTest, ContextBothComputations) { EXPECT_EQ(CallContext::kBoth, sub_node.context()); } +TEST_F(CallGraphTest, ComputationWithConditional) { + // Test a call graph of a module with a conditional. + auto module = CreateNewModule(); + HloComputation* true_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); + HloComputation* false_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor)); + + HloComputation::Builder builder(TestName()); + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction* const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction* const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.6f))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, const1, true_computation, const2, + false_computation)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + std::unique_ptr call_graph = CallGraph::Build(module.get()); + + EXPECT_EQ(3, call_graph->nodes().size()); + + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(1, entry_node.callsites().size()); + + const CallSite& conditional_callsite = entry_node.callsites()[0]; + EXPECT_EQ(conditional, conditional_callsite.instruction()); + EXPECT_THAT(conditional_callsite.called_computations(), + UnorderedElementsAre(true_computation, false_computation)); + EXPECT_EQ(CallContext::kSequential, conditional_callsite.context()); + EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); + + const CallGraphNode& true_node = call_graph->GetNode(true_computation); + EXPECT_TRUE(true_node.callees().empty()); + EXPECT_EQ(1, true_node.callers().size()); + EXPECT_EQ(entry_computation, true_node.callers()[0]); + + const CallGraphNode& false_node = call_graph->GetNode(false_computation); + EXPECT_TRUE(false_node.callees().empty()); + EXPECT_EQ(1, false_node.callers().size()); + EXPECT_EQ(entry_computation, false_node.callers()[0]); +} + TEST_F(CallGraphTest, ComplexGraph) { // Test a call graph of a module with several computation called in various // contexts. The call graph looks like: diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index 3b1900428af1863c73efe67c27061d979557b3a4..e2e9d2a0c048fec6c6ffbeef1223ae0e6aef50d1 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -27,14 +27,8 @@ namespace se = ::perftools::gputools; namespace xla { -/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_; - -/* static */ void Compiler::LazyInitMutex() { - static std::once_flag mutex_init_flag; - std::call_once(mutex_init_flag, []() { - Compiler::platform_compiler_mutex_ = new tensorflow::mutex; - }); -} +/* static */ tensorflow::mutex Compiler::platform_compiler_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* @@ -55,8 +49,7 @@ Compiler::GetPlatformCompilers() { /* static */ void Compiler::RegisterCompilerFactory( se::Platform::Id platform_id, std::function()> compiler_factory) { - LazyInitMutex(); - tensorflow::mutex_lock lock(*platform_compiler_mutex_); + tensorflow::mutex_lock lock(platform_compiler_mutex_); auto* factories = GetPlatformCompilerFactories(); CHECK(factories->find(platform_id) == factories->end()) << "Compiler factory already registered for platform"; @@ -65,8 +58,7 @@ Compiler::GetPlatformCompilers() { /* static */ StatusOr Compiler::GetForPlatform( const se::Platform* platform) { - LazyInitMutex(); - tensorflow::mutex_lock lock(*platform_compiler_mutex_); + tensorflow::mutex_lock lock(platform_compiler_mutex_); auto* compilers = GetPlatformCompilers(); // See if we already instantiated a compiler for this platform. diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 4c2d9600d909e82dcb62f508a10445c08c1cdee6..fc67330f5cbdbcb0d1a259d284599916a908d1fe 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -97,21 +97,32 @@ class Compiler { // Returns the ID of the platform that this compiler targets. virtual perftools::gputools::Platform::Id PlatformId() const = 0; + // Runs Hlo passes to optimize the given Hlo module, returns the optimized + // module. + virtual StatusOr> RunHloPasses( + std::unique_ptr module, + perftools::gputools::StreamExecutor* executor) = 0; + // Compiles the HLO module for execution on a device given by the executor, - // and returns an executable object or an error status. Takes ownership of the - // HLO module and is free to transform it. + // and returns an executable object or an error status. No HLO passes are + // applied to module. Generally a module should be passed through RunHloPasses + // prior to calling this method because the some HLO passes are required for + // correctness. Takes ownership of the HLO module and is free to transform it. // // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. // // Use the overload below to compile computations that run in parallel. - virtual StatusOr> Compile( + virtual StatusOr> RunBackend( std::unique_ptr module, perftools::gputools::StreamExecutor* executor) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. + // + // TODO(b/68666782): Remove this method after adding support for multiple + // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( std::vector> modules, std::vector> @@ -157,8 +168,7 @@ class Compiler { private: // Mutex that guards the platform-compiler map. - static tensorflow::mutex* platform_compiler_mutex_; - static void LazyInitMutex(); + static tensorflow::mutex platform_compiler_mutex_; // Map from platform kind to compiler factory. static std::map* diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index cdfa30dd9a7b6a5b9e58087491a9d99caaa1b998..657fba6b6231104bf47f9dec80f7cd36a0ba3efd 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -52,6 +52,12 @@ Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { /* static */ StatusOr> DeviceAssignment::Deserialize(const DeviceAssignmentProto& proto) { TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); + if (proto.replica_count() <= 0 || proto.computation_count() <= 0) { + return InvalidArgument( + "Invalid device assignment topology: replica_count=%d, " + "computation_count=%d", + proto.replica_count(), proto.computation_count()); + } auto assignment = MakeUnique(proto.replica_count(), proto.computation_count()); for (int computation = 0; computation < proto.computation_count(); @@ -94,7 +100,7 @@ StatusOr ComputationPlacer::AssignDevices( se::Platform::Id platform_id, ComputationPlacerCreationFunction creation_function) { tensorflow::mutex_lock lock( - *ComputationPlacer::platform_computation_placer_mutex()); + ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); CHECK(computation_placers->find(platform_id) == computation_placers->end()); (*computation_placers)[platform_id].creation_function = creation_function; @@ -103,7 +109,7 @@ StatusOr ComputationPlacer::AssignDevices( /* static */ StatusOr ComputationPlacer::GetForPlatform( const se::Platform* platform) { tensorflow::mutex_lock lock( - *ComputationPlacer::platform_computation_placer_mutex()); + ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); auto it = computation_placers->find(platform->id()); @@ -122,11 +128,9 @@ StatusOr ComputationPlacer::AssignDevices( return it->second.placer.get(); } -/* static */ tensorflow::mutex* -ComputationPlacer::platform_computation_placer_mutex() { - static tensorflow::mutex* m = new tensorflow::mutex; - return m; -} +/* static */ tensorflow::mutex + ComputationPlacer::platform_computation_placer_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h index 7d9abcd100dd9e878da885110bc1bd1ac65e3f84..737ccabaa7a61931b6e2787f75b02857562d4820 100644 --- a/tensorflow/compiler/xla/service/computation_placer.h +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -89,11 +89,8 @@ class ComputationPlacer { const perftools::gputools::Platform* platform); private: - // Routine that returns the mutex that guards the platform-to-computation - // placer map. Done as a routine to ensure correct initialization ordering, - // since RegisterComputationPlacer can be called during program initialization - // time. - static tensorflow::mutex* platform_computation_placer_mutex(); + // The mutex that guards the platform-to-computation placer map. + static tensorflow::mutex platform_computation_placer_mutex_; // State kept for each kind of ComputationPlacer. Registration functions set // up creation_function, and then we use that to lazily create "placer" the diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 0453a698a09b740d68b35258ede7c537fcf290d4..cd983bc03e993caed883916de01d75dffdbc4bab 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,15 +15,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" -#include - +#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -31,597 +33,1174 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { +using ::tensorflow::str_util::Join; +using ::tensorflow::strings::StrAppend; +using ::tensorflow::strings::StrCat; + namespace { -using tensorflow::gtl::FlatMap; -using tensorflow::gtl::FlatSet; +bool IsEntryParameterValue(const HloValue& value) { + const HloComputation* computation = value.defining_instruction()->parent(); + return value.defining_instruction()->opcode() == HloOpcode::kParameter && + computation == computation->parent()->entry_computation(); +} + +bool IsConstantValue(const HloValue& value) { + return value.defining_instruction()->opcode() == HloOpcode::kConstant; +} + +bool ValueIsReadOnly(const HloValue& value) { + return IsConstantValue(value) || IsEntryParameterValue(value); +} -// InstructionCopier encapsulates indices at which to copy 'instruction'. -// All 'instruction' users in 'copy_users' are updated to use the copy. +// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in +// 'indices_to_copy'. Add control edges from the respective kCopy instructions +// in deep copy of 'from' to the respective kCopy instruction in the deep copy +// of 'to'. // -// Instruction copies are generated in two phases: -// 1) Recording buffer indices at which 'instruction' requires copies (i.e. -// setting 'indices_to_copy_[index]'=true). -// 2) Inserting kCopy instructions based on indices recorded in phase 1). -// *) Array instructions are copied by inserting a single kCopy instruction. -// *) Tuple-shaped instructions are copied by recursively expanding tuples -// (and tuple-shaped elements), and inserting kCopy instructions for any -// tuple elements which require a copy. As the recursion unwinds, new tuple -// instructions are added to gather the copied (and uncopied) references -// into the output tuple (i.e. the copy of the tuple-shaped instruction). +// Requirements: 'from' and 'to' must have compatible shapes. // -// Example two-element tuple with one element that needs a copy: +// For example, suppose 'from' and 'to' are two-element tuples where index 0 is +// the only index to copy. Prior to deep-copying we have: // -// original-instruction -// / \ -// GTE(0) GTE(1) -// | | -// Copy | -// \ / -// Tuple // copied-instruction // -// As an optimization, if the original instruction is itself a Tuple -// instruction, we elide the unnecessary extra GTE and Tuple instructions, -// and just insert the copy into a new Tuple instruction, with control -// dependencies to ensure the copy occurs after any possible interference. -class InstructionCopier { - public: - InstructionCopier(HloInstruction* instruction, - const std::vector& copy_users) - : instruction_(instruction), - copy_users_(copy_users), - indices_to_copy_(instruction->shape()), - control_predecessors_(instruction->shape()) {} - - // Sets indices that are read-only, and thus do not need to be copied. - void SetReadOnlyIndices(const ShapeTree& read_only_indices) { - read_only_indices_ = read_only_indices; - } +// 'from' +// | +// ... +// | +// 'to' +// +// DeepCopyAndAddControlEdges produces: +// +// 'from' +// / \ +// GTE GTE +// | | +// Copy | +// / \ / +// | Tuple +// | | +// ctrl ... +// edge | +// | | +// | 'to' +// | / \ +// | GTE GTE +// \ | | +// Copy | +// \ / +// Tuple +// +StatusOr> +DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, + const ShapeTree& indices_to_copy) { + DCHECK(ShapeUtil::Compatible(from->shape(), to->shape())); + // to/from_copy_tree hold the kCopy instruction produces by the deep + // copies. Elements which are not copied (indices_to_copy.element(index) == + // false) have nullptr at that index. + ShapeTree from_copy_tree(from->shape(), + /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy, + from->parent()->DeepCopyInstruction( + from, &indices_to_copy, &from_copy_tree)); - // Sets copy overrides, which are copy instructions to use at each index. This - // is used to share a single copy of read-only entry parameters and constants - // between multiple While loops. - void SetCopyOverrides(const ShapeTree& copy_overrides) { - copy_overrides_ = copy_overrides; + ShapeTree to_copy_tree(to->shape(), /*init_value=*/nullptr); + TF_ASSIGN_OR_RETURN( + HloInstruction * to_deep_copy, + to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree)); + + // Add control edges between the respective kCopy instructions. + for (const auto& pair : from_copy_tree) { + const ShapeIndex& index = pair.first; + HloInstruction* from_copy = pair.second; + HloInstruction* to_copy = to_copy_tree.element(index); + if (from_copy == nullptr) { + TF_RET_CHECK(to_copy == nullptr); + continue; + } + TF_RET_CHECK(to_copy != nullptr); + TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy)); } - // Returns true if all recorded indices are false (returns true otherwise). - bool HasAllIndicesFalse() const; + return std::make_pair(from_deep_copy, to_deep_copy); +} - // Records instruction buffer indices which point-to a Parameter or Constant. - Status RecordIndicesWhichPointToParamOrConstant( - const TuplePointsToAnalysis& points_to_analysis); +// Compute the indices of the loop state which need copies in order to avoid +// live range interference. Generally, an element in the loop state does not +// need to be copied if the element is passed through transparently through the +// body. +// +// Returns whether any indices need to be copied. +bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, + const HloInstruction* xla_while, + ShapeTree* indices_to_copy) { + DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape())); - // Records instruction buffer indices to copy which are necessary to ensure: - // *) PointsToSet of 'instruction_' is unambiguous and distinct. - // *) No liveness interference between 'instruction_' and 'other_instruction'. - // - // If 'read_only_indices_out' is non-null, read-only indices are set to true. - Status RecordIndicesToCopyForColocatingBuffers( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out); + bool any_copies = false; + const HloInstruction* init = xla_while->operand(0); + for (auto& pair : *indices_to_copy) { + const ShapeIndex& index = pair.first; + bool& should_copy = pair.second; + // If there is any ambiguity, then loop state must be copied. + if (dataflow.GetValueSet(init, index).values().size() > 1 || + dataflow.GetValueSet(xla_while, index).values().size() > 1) { + should_copy = true; + } else { + // If the output of the while instruction is not the same as the init + // value of the while, then this element is not passed through the body + // transparently and must be copied. + should_copy = dataflow.GetUniqueValueAt(xla_while, index) != + dataflow.GetUniqueValueAt(init, index); + } + any_copies |= should_copy; + } + return any_copies; +} - // Records control predecessors to add for inserted copy instructions. - // 'parameter' must have the same shape as the instruction that will be - // copied, and must define all buffers in the shape. Control predecessors are - // only recorded for indices that have already been marked for copying. - Status RecordControlPredecessors( - const TuplePointsToAnalysis& points_to_analysis, - HloInstruction* parameter); +// Add kCopy instructions around the given kWhile instruction to eliminate any +// possible live range interference of HLO values assuming a dependency-based +// ordering (HloDependencyOrdering). Copies are added conservatively. There +// likely are copies which are not strictly necessary, but there are removed +// later in the pass via CopyRemover. +// +// +// Elements (each ShapeIndex) in the loop state are considered independently. A +// copy is added to each element of the loop state which is modified in the +// while body. For each such element, a total of three kCopy instructions are +// added at following locations: +// +// (1) The init value is copied before the kWhile instruction. Before: +// +// (Init) +// | +// kWhile +// | +// ... +// +// After: +// +// (Init) +// | +// kCopy +// | +// kWhile +// | +// ... +// +// This copy is necessary in case the init value is simultaneously live +// with the kWhile. +// +// (2) Copies are added to the parameter and root of the while body +// computation. Before: +// +// kParameter +// | +// ... +// | +// (body root) +// +// After: +// +// kParameter +// | +// kCopy ----------+ +// | | +// ... ctrl +// | edge +// (body root) | +// | | +// kCopy <---------+ +// +// The root kCopy becomes the new root of the computation. Both copies are +// necessary to any potential interference between the parameter value and +// the root value. The control edge prevents potential interference +// between the copies themselves. +// +// If the loop state is a tuple then the above kCopy instructions are a deep +// copy constructed of kCopy, KGetTupleElement, and kTuple instruction as +// constructed by HloInstruction::DeepCopyInstruction. +Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, + HloInstruction* xla_while) { + VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name(); + TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile); - // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy', - // and replaces all uses for instructions in 'copy_users_' with copy. - // Returns the instruction which is a copy 'instruction'. - HloInstruction* Copy(); + ShapeTree indices_to_copy(xla_while->shape()); + if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while, + &indices_to_copy)) { + VLOG(2) << "No copies necessary for kWhile instruction " + << xla_while->name(); + return Status::OK(); + } - HloInstruction* instruction() { return instruction_; } + VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:"; + for (auto& pair : indices_to_copy) { + if (pair.second) { + VLOG(2) << " " << pair.first; + } + } - const std::vector& copy_users() const { return copy_users_; } + // Deep copy init. + HloInstruction* while_init = xla_while->mutable_operand(0); + TF_ASSIGN_OR_RETURN( + HloInstruction * while_init_copy, + xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy)); + TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy)); - private: - // Does the given index represent a read-only buffer? - bool IsReadOnlyIndex(const ShapeIndex& index) const { - return !ShapeUtil::IsNil(read_only_indices_.shape()) && - read_only_indices_.element(index); - } + // Deep copy the parameter and the root. Extend a control edge from the copy + // of the parameter value to the corresponding copy value of the root. + HloComputation* body = xla_while->while_body(); + HloInstruction* param = body->parameter_instruction(0); + HloInstruction* root = body->root_instruction(); - // Returns the copy override at the given index, or nullptr. - HloInstruction* GetCopyOverride(const ShapeIndex& index) const { - return ShapeUtil::IsNil(copy_overrides_.shape()) - ? nullptr - : copy_overrides_.element(index); - } + // If param is the root then all indices should have been passed through the + // while body and we should have returned early above. + TF_RET_CHECK(param != root); - // Records instruction buffer indices which have ambiguous or non-distinct - // points-to sets. - Status RecordAmbiguousOrNonDistinctIndices( - const TuplePointsToAnalysis& points_to_analysis); + // Copy users before making a deep copy of the parameter as the deep copy + // will create new users of the parameter (eg, the GTE instructions of the + // deep copy). + std::vector param_users = param->users(); - // Records instruction buffer indices which have interfering live ranges - // with 'other_instruction' buffers at same index. - Status RecordIndicesWhichInterfereWithOtherInstruction( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out); + ShapeIndex current_index; + TF_ASSIGN_OR_RETURN(auto pair, + DeepCopyAndAddControlEdges(param, root, indices_to_copy)); - // Recursively inserts copies of 'instruction' tuple elements at indices - // specified in 'indices_to_copy', and returns the copy of 'instruction'. - HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index); + HloInstruction* param_copy = pair.first; + HloInstruction* root_copy = pair.second; - void RecordIndex(const ShapeIndex& index) { - *indices_to_copy_.mutable_element(index) = true; + for (HloInstruction* user : param_users) { + TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy)); } - HloInstruction* instruction_; - const std::vector copy_users_; - ShapeTree indices_to_copy_; - ShapeTree> control_predecessors_; - ShapeTree read_only_indices_; - ShapeTree copy_overrides_; -}; + body->set_root_instruction(root_copy); -bool InstructionCopier::HasAllIndicesFalse() const { - bool all_indices_false = true; - indices_to_copy_.ForEachElement( - [&all_indices_false](const ShapeIndex& /*index*/, bool data) { - if (data) { - all_indices_false = false; - } - }); - return all_indices_false; + return Status::OK(); } -Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( - const TuplePointsToAnalysis& points_to_analysis) { - const PointsToSet& points_to = - points_to_analysis.GetPointsToSet(instruction_); - // Shallow copy the instruction if the points-to set of the top-level - // buffer is ambiguous. This is necessary because the backends must know - // statically what the top-level buffer of the result is. - if (points_to.element(/*index=*/{}).size() > 1) { - RecordIndex({}); +// Removes any control dependencies to or from the given instruction. +Status StripControlDependenciesFrom(HloInstruction* instruction) { + while (!instruction->control_successors().empty()) { + TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo( + instruction->control_successors().front())); + } + + while (!instruction->control_predecessors().empty()) { + TF_RETURN_IF_ERROR( + instruction->control_predecessors().front()->RemoveControlDependencyTo( + instruction)); } - // Multiple buffers within a parameter/constant may be live out, so collect - // a set of indices at which to copy first. - points_to.ForEachElement([this](const ShapeIndex& index, - const PointsToSet::BufferList& buffers) { - if (IsReadOnlyIndex(index)) { - return; - } - for (const LogicalBuffer* buffer : buffers) { - // pointee is the HloInstruction producing the buffer which may be - // liveout. - HloInstruction* pointee = buffer->instruction(); - if (pointee->opcode() == HloOpcode::kParameter || - pointee->opcode() == HloOpcode::kConstant) { - VLOG(2) << "Parameter or constant buffer " << buffer->ToString() - << " index: " << tensorflow::str_util::Join(index, ",") - << " may be live out of computation: " << pointee->ToString(); - RecordIndex(index); - break; - } - } - }); return Status::OK(); } -Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out) { - TF_RETURN_IF_ERROR( - RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis())); - TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction( - liveness, other_instruction, read_only_indices_out)); +// Add kCopy instructions to the given module to guarantee there is no +// live-range interference. Generally interference can only occur around kWhile +// instructions which have update-in-place semantics. +Status AddCopiesToResolveInterference(HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile) { + TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); + } + } + } return Status::OK(); } -Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( - const TuplePointsToAnalysis& points_to_analysis) { - const PointsToSet& points_to = - points_to_analysis.GetPointsToSet(instruction_); - // Mapping from LogicalBuffer to index (used to detect non-distinct indices). - FlatMap> - buffer_to_source_indices; - points_to.ForEachElement( - [this, &buffer_to_source_indices]( - const ShapeIndex& index, const PointsToSet::BufferList& buffers) { - if (buffers.size() > 1) { - // Record ambiguous points-to set at 'index'. - if (!indices_to_copy_.element(index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " with ambiguous points-to set."; - RecordIndex(index); +// Class for removing unnecessary copies from the module. +// +// kCopy instructions are added conservatively to guarantee no live range +// interference between HLO values. This class uses a more fine-grained analysis +// to remove some of these added copies which are not strictly necessary. +class CopyRemover { + public: + CopyRemover(const HloAliasAnalysis& alias_analysis, + const HloOrdering& ordering, HloModule* module) + : module_(module), + alias_analysis_(alias_analysis), + ordering_(ordering), + buffer_value_tracker_(*module, alias_analysis, ordering) {} + + // Try to elide the given copy. The copy is elided if the instruction is not + // necessary to prevent live-range interference of HLO values. Returns true if + // copy was elided. + // + // The copy instruction is not actually removed here. Instead it is left for + // dead in the graph. Later calls to DCE will remove the instruction. + StatusOr TryElideCopy(HloInstruction* copy) { + if (buffer_value_tracker_.TryElideCopy(copy)) { + TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy)); + TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0))); + return true; + } + return false; + } + + string ToString() const { + string out = StrCat("CopyRemover, module ", module_->name(), "\n"); + StrAppend(&out, " Buffer values, in dependency order:\n"); + for (const HloBuffer& buffer : alias_analysis_.buffers()) { + StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); + } + return out; + } + + private: + // Class which tracks the HLO values within each HLO buffer in the module + // during copy removal. + // + // The values are held in a linked list where there is one list for each + // buffer. Removing a copy instruction merges together the values in the + // source buffer of the copy to the destination buffer of the copy. This class + // tracks these value lists as copies are removed from the graph (and value + // lists are merged). + // + // The BufferValueTracker object is initialized to match the state of + // HloAliasAnalysis. However, as copies are removed this state diverges. The + // values-to-buffer mapping is maintained outside of HloAliasAnalysis because + // a fully updatable alias analysis is very slow. + class BufferValueTracker { + public: + // The values held in a single HLO buffer are represented using a linked + // list. An element type in this list is ValueNode. + // + // This linked list is hand-rolled to enable efficient splicing of lists + // using only references to list elements without knowing which lists are + // being spliced. std::list requires a reference to the list object to + // splice. + struct ValueNode { + explicit ValueNode(const HloValue* v) : value(v) {} + + const HloValue* value; + + // The uses are maintained outside of HloValue::uses() because + // HloValue::uses() is not updatable (a fully updatable dataflow analysis + // is slow). + std::vector uses; + + // next/prev elements in the linked list. The list is circularly linked so + // these values are never null for elements in the list. + ValueNode* prev = nullptr; + ValueNode* next = nullptr; + }; + + BufferValueTracker(const HloModule& module, + const HloAliasAnalysis& alias_analysis, + const HloOrdering& ordering) + : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { + // Construct a list for each HLO buffer in the alias analysis. Maintain a + // map from HloValue to the respective list element representing that + // value. The map is used to construct the copy info map below. + tensorflow::gtl::FlatMap value_to_node; + for (const HloBuffer& buffer : alias_analysis.buffers()) { + // Verify values contained in the buffer are strictly ordered. This + // should always be the case after adding copies to eliminate + // interference. Specifically, the addition of the control flow edges + // between copies added around aliased operations (kWhile) guarantees + // this strict order. + for (const HloValue* value_a : buffer.values()) { + for (const HloValue* value_b : buffer.values()) { + if (value_a != value_b) { + DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, + dataflow_) || + ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, + dataflow_)) + << value_a->ToShortString() << " and " + << value_b->ToShortString() << " are not ordered"; + } } } - // For each 'buffer': record a mapping from 'buffer' to 'index'. - for (const LogicalBuffer* buffer : buffers) { - buffer_to_source_indices[buffer].push_back(index); - } - }); - // Record all non-distinct indices detected in 'buffer_to_source_indices'. - for (const auto& buff_to_src : buffer_to_source_indices) { - if (buff_to_src.second.size() == 1) { - continue; + std::vector values = buffer.values(); + std::sort(values.begin(), values.end(), + [this](const HloValue* a, const HloValue* b) { + return ordering_.IsDefinedBefore(*a, *b); + }); + + // Create a list containing all of the values in the buffer. + AddValueList(values, &value_to_node); + } + + // Create copy_map_ which contains the source and destination values + // of all copies. + CreateCopyMap(module, value_to_node); + + XLA_VLOG_LINES(3, ToString()); + TF_DCHECK_OK(Verify()); } - for (const ShapeIndex& src_index : buff_to_src.second) { - // Record non-distinct points-to set at 'src_index'. - if (!indices_to_copy_.element(src_index)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " at index: " << tensorflow::str_util::Join(src_index, ",") - << " because of non-distinct points-to set."; - RecordIndex(src_index); + + // Add a list containing the given values to BufferValueTracker. This + // represents the values contained in a single buffer. For each value in + // 'values' an entry is created in value_to_node which indicates the + // respective ValueNode representing that value. + void AddValueList( + tensorflow::gtl::ArraySlice values, + tensorflow::gtl::FlatMap* value_to_node) { + ValueNode* tail = nullptr; + ValueNode* head = nullptr; + for (const HloValue* value : values) { + auto new_node = new ValueNode(value); + (*value_to_node)[value] = new_node; + + // Copy the HLO values's uses into the ValueNode for the value. These + // uses in ValueNode are updated as copies are removed. + new_node->uses.reserve(value->uses().size()); + for (const HloUse& use : value->uses()) { + new_node->uses.push_back(&use); + } + + // Connect the new node into the linked list. + if (tail == nullptr) { + head = new_node; + } else { + tail->next = new_node; + new_node->prev = tail; + } + tail = new_node; } + + // The linked list is circular so connect the head and tail. + tail->next = head; + head->prev = tail; + value_lists_.insert(head); } - } - return Status::OK(); -} -Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( - const BufferLiveness& liveness, const HloInstruction* other_instruction, - ShapeTree* read_only_indices_out) { - // Record all buffer indices for 'instruction_', which interfere with - // 'other_instruction' at the same index. - ShapeUtil::ForEachSubshape( - instruction_->shape(), - [this, &liveness, other_instruction, read_only_indices_out]( - const Shape& /*subshape*/, const ShapeIndex& index) { - if (IsReadOnlyIndex(index)) { - return; + // This method also fills in copy_map_ which indicates which nodes + // in the value lists corresponding to the source and destination values of + // kCopy instructions. value_to_node should map each HloValue to its + // respective ValueNode. + void CreateCopyMap( + const HloModule& module, + const tensorflow::gtl::FlatMap& + value_to_node) { + for (HloComputation* computation : module.computations()) { + for (HloInstruction* instruction : computation->instructions()) { + // Add copies with unambiguous source values to the map. Copies with + // ambiguous sources are not removable. + if (instruction->opcode() == HloOpcode::kCopy) { + const HloValueSet& src_value_set = + dataflow_.GetValueSet(instruction->operand(0)); + if (src_value_set.values().size() == 1) { + CopyNodes& copy_node = copy_map_[instruction]; + copy_node.dest = + value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); + copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); + } + } } - if (indices_to_copy_.element(index)) { - // Return if previous pass already set index. - return; + } + } + + ~BufferValueTracker() { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + const ValueNode* tmp = p->next; + delete p; + p = tmp; + } while (p != head); + } + } + + // Verify invariants within the linked lists. + Status Verify() const { + for (const ValueNode* head : value_lists_) { + const ValueNode* p = head; + do { + // Verify links between elements are consistent. + TF_RET_CHECK(p->prev->next == p); + TF_RET_CHECK(p->next->prev == p); + + const HloInstruction* def = p->value->defining_instruction(); + if (def->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, def)) { + TF_RET_CHECK(copy_map_.at(def).dest == p); + } + for (const HloUse* use : p->uses) { + if (use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, use->instruction)) { + TF_RET_CHECK(copy_map_.at(use->instruction).src == p); + } + } + + p = p->next; + } while (p != head); + } + return Status::OK(); + } + + // Try to elide the given copy. Elision of a copy is possible only if no + // live range interference is introduced by the copy's elimination. If + // elision is possible, then the internal state (value lists) are updated, + // and true is returned. Returns false otherwise. + bool TryElideCopy(const HloInstruction* copy) { + VLOG(2) << "Trying to remove " << copy->name(); + + if (!ContainsKey(copy_map_, copy)) { + VLOG(2) << copy->name() << " is not removable"; + return false; + } + + const CopyNodes& copy_node = copy_map_.at(copy); + ValueNode* src = copy_node.src; + ValueNode* dest = copy_node.dest; + DCHECK(src != nullptr); + DCHECK(dest != nullptr); + + auto is_live_range_before = [this](const ValueNode& a, + const ValueNode& b) { + if (LiveRangeBefore(a, b)) { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is before " << b.value->ToShortString(); + return true; + } else { + VLOG(2) << " Live range of " << a.value->ToShortString() + << " is not before " << b.value->ToShortString(); + return false; } - const auto& points_to_analysis = liveness.points_to_analysis(); - // Lookup buffers for 'instruction_' and 'other_instruction'. - const auto instruction_buffers = - points_to_analysis.GetPointsToSet(instruction_).element(index); - // If 'instruction_' has ambiguous points-to-set at 'index', it would - // have been recorded in a previous pass (and we would have returned - // early at the entry to this function). As a result, here we know that - // 'instruction_' has just one buffer in its points-to-set. - CHECK_EQ(1, instruction_buffers.size()); - const LogicalBuffer* instruction_buffer = instruction_buffers[0]; - - const auto other_instruction_buffers = - points_to_analysis.GetPointsToSet(other_instruction).element(index); - // Do not insert a copy if both instructions point at the same buffer. - // This eliminates unnecessary copies of read-only tuple elements. - // If 'instruction_' and 'other_instruction' point to the same buffer, - // then that buffer is not updated on the path between the two - // instructions. Therefore, any other (possibly interference-causing) - // users of that buffer from 'other_instruction' will see the same data, - // irrespective of whether we insert a copy of this buffer at - // 'instruction_' or not. - if (other_instruction_buffers.size() == 1 && - other_instruction_buffers[0]->id() == instruction_buffer->id()) { - if (read_only_indices_out != nullptr) { - *read_only_indices_out->mutable_element(index) = true; + }; + + VLOG(3) << copy->name() << " copies value " + << src->value->ToShortString(); + VLOG(3) << "Source buffer values: " << ValueListToString(src); + VLOG(3) << "Dest buffer values: " << ValueListToString(src); + + // A kCopy instruction copies an HLO value from a source buffer and + // defines an HLO value in a destination buffer. Most generally, the + // source and destination buffers may each hold more than one value at + // different points in the computation so we define the following: + // + // Values in source buffer: {s_0, ..., s_n} + // Values in destination buffer: {d_0, ..., d_m} + // + // A kCopy instruction between these buffers copies a value s_x in the + // source buffer and defines a value d_y in the destination buffer. The + // elision of a copy merges the source and destination buffers together, + // so the list of values for the source and destination buffers are + // merged. + // + // We handle two different cases for copy elision: + // + // (1) the kCopy defines the first value in the destination buffer (d_0). + // + // (2) the kCopy copies the last value in the source buffer (s_n). + // + // For the remaining case where the kCopy copies a not-last value from the + // source buffer to a not-first value of the destination buffer, the kCopy + // instruction cannot be removed. This case is generated, for example, if + // the kCopy copies a while body parameter of the loop state at one tuple + // index to a different tuple index in the while body root. Removal of the + // copy necessarily results in live range interference of values in the + // loop state at the two different tuple indices. + // + // We can only perform copy elision if the resulting merged values have + // totally ordered live ranges; otherwise the merged buffer would have + // live range interference. + if (IsHead(*dest)) { + // The copy copies an arbitrary value in the source buffer (call it s_x) + // and defines d_0, the first value in the destination buffer. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} + // + // Removing the copy eliminates d_0, and uses of d_0 become uses of + // s_x. In the above ordering, the live range of d_m must be ordered + // before the live range of s_{x+1} and the definition and all uses of + // s_x must be ordered before the definition of d_1. These conditions + // are checked below prior to elision. + // + // ** Technically it might be possible to have a non-interfering + // non-trivial interleaving of the values of the source and + // destination buffers in the resulting order. However, this case is + // slow and complicated to check and likely not worth it. So instead + // we simply check for the case where *all* values of the destination + // buffer (d_1 through d_m) are spliced into the point where the copy + // used to be. + VLOG(2) << copy->name() << " defines the first value in its buffer"; + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); + if (!is_live_range_before(*src, *next_dest)) { + return false; } - return; } - // We can't say anything about the ambiguity of 'other_instruction' at - // this point, so we need to check interference between the single - // buffer in the points-to set of 'instruction_' and all buffers in - // 'other_instruction_buffers'. - for (const LogicalBuffer* other_buffer : other_instruction_buffers) { - if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) { - VLOG(2) << "Adding copy of buffer for instruction: " - << instruction_->name() - << " instruction_buffer: " << instruction_buffer->ToString() - << " at index: " << tensorflow::str_util::Join(index, ",") - << " because of interference with buffer: " - << other_buffer->ToString(); - RecordIndex(index); - break; + ValueNode* next_src = Next(*src); + + if (next_src != nullptr) { + // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. + ValueNode* last_dest = dest->prev; + DCHECK(IsTail(*last_dest)); + if (!is_live_range_before(*last_dest, *next_src)) { + return false; } } - }); - return Status::OK(); -} -// This is called when 'instruction_' is a while body root, and 'parameter' is -// the while body parameter. We record all users of all aliases of 'parameter' -// as control predecessors, so that when we add a copy of 'instruction_', we can -// mark the control dependencies. This is necessary because points-to and -// liveness analysis doesn't know about the aliasing between the while body root -// and param. Without these control dependencies, the copy might get scheduled -// to run at a point that interferes with users of the buffer. -Status InstructionCopier::RecordControlPredecessors( - const TuplePointsToAnalysis& points_to_analysis, - HloInstruction* parameter) { - return indices_to_copy_.ForEachElementWithStatus( - [this, &points_to_analysis, parameter](const ShapeIndex& index, - bool will_copy) { - if (will_copy) { - TF_ASSIGN_OR_RETURN( - const LogicalBuffer* buffer, - points_to_analysis.GetBufferDefinedAt(parameter, index)); - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - for (HloInstruction* user : alias.instruction()->users()) { - if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), - user, points_to_analysis)) { - continue; - } - - if (user != instruction_) { - control_predecessors_.mutable_element(index)->push_back(user); - } - } + // Splice in destination buffer values list right after 'src'. + SpliceAfter(dest, src); + } else if (IsTail(*src)) { + // The copy copies the last value in the source buffer, s_n, and defines + // an arbitrary value in the destination buffer, d_y. After + // merging, the values in the combined buffer must be strictly ordered + // as follows** to elide the copy: + // + // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} + // + // Removing the copy eliminates d_y, and uses of d_y become uses of + // s_n. To enforce the above order, the live range of d_{y-1} must be + // before the live range of s_0, and the live range of s_n must be + // before the live range of d_{y+1}. + // + // ** See comment above in the code handling Case (1). + VLOG(2) << copy->name() << " copies the last value (" + << src->value->ToShortString() << ") in its buffer"; + + ValueNode* prev_dest = Prev(*dest); + // nullptr condition handled above in the first 'if' case. + DCHECK(prev_dest != nullptr); + ValueNode* first_src = src->next; + DCHECK(IsHead(*first_src)); + if (!is_live_range_before(*prev_dest, *first_src)) { + // Live range of value d_{y-1} is not before s_0. + return false; + } + ValueNode* next_dest = Next(*dest); + if (next_dest != nullptr) { + if (!is_live_range_before(*src, *next_dest)) { + // Live range of value s_n is not before d_{y+1}. + return false; } } - return Status::OK(); - }); -} -// Recursively inserts copies of 'instruction' tuple element buffers at -// indices in 'indices_to_copy_', expanding tuples as needed. -HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction, - ShapeIndex* index) { - const int64 num_tuple_elements = - ShapeUtil::TupleElementCount(instruction->shape()); - std::vector elem_copies(num_tuple_elements); - for (int64 i = 0; i < num_tuple_elements; ++i) { - HloInstruction* elem; - if (instruction->opcode() == HloOpcode::kTuple) { - // If the instruction is already a Tuple instruction, we know that the - // element buffers are aliased, so we can just grab the operand directly. - elem = instruction->mutable_operand(i); - } else { - // Otherwise we need to add a GTE to unpack the element out of the tuple. - elem = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - } - index->push_back(i); - if (ShapeUtil::IsTuple(elem->shape())) { - elem_copies[i] = CopyTuple(elem, index); - } else if (!indices_to_copy_.element(*index)) { - elem_copies[i] = elem; - } else if (HloInstruction* copy_override = GetCopyOverride(*index)) { - elem_copies[i] = copy_override; - } else { - HloInstruction* elem_copy = elem->parent()->AddInstruction( - HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem)); - for (HloInstruction* control_predecessor : - control_predecessors_.element(*index)) { - VLOG(2) << "Adding control dependency from " - << control_predecessor->ToString() << " to " - << elem_copy->ToString(); - TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy)); + // Splice source buffer values list right after 'prev_dest'. + SpliceAfter(first_src, prev_dest); + } else { + VLOG(2) + << copy->name() + << " copies value in middle of source buffer to value in middle " + "of destination buffer"; + return false; } - elem_copies[i] = elem_copy; + + RemoveCopyValue(dest); + + XLA_VLOG_LINES(4, ToString()); + TF_DCHECK_OK(Verify()); + + return true; } - index->pop_back(); - } - return instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(elem_copies)); -} -// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'. -HloInstruction* InstructionCopier::Copy() { - ShapeIndex index; - HloInstruction* copy; - if (ShapeUtil::IsTuple(instruction_->shape())) { - copy = CopyTuple(instruction_, &index); - } else { - copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary( - instruction_->shape(), HloOpcode::kCopy, instruction_)); - } - for (HloInstruction* user : copy_users_) { - VLOG(2) << "Adding copy between instruction: " << instruction_->name() - << " and user: " << user->name(); - TF_CHECK_OK(instruction_->ReplaceUseWith(user, copy)); + // Delete the given ValueNode associated with a elided kCopy + // instruction. This should be called after splicing the value lists of the + // source and destination buffers together. + void RemoveCopyValue(ValueNode* copy_value_node) { + CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), + HloOpcode::kCopy); + ValueNode* operand_node = copy_value_node->prev; + CHECK(operand_node != copy_value_node); + + VLOG(2) << "Removing copy " << operand_node->value->ToShortString() + << " => " << copy_value_node->value->ToShortString(); + + // Splice out the copy value node. + operand_node->next = copy_value_node->next; + copy_value_node->next->prev = operand_node; + + // Patch up uses. Remove use of copy from operand_node uses. + auto it = + std::find_if(operand_node->uses.begin(), operand_node->uses.end(), + [copy_value_node](const HloUse* use) { + return use->instruction == + copy_value_node->value->defining_instruction(); + }); + CHECK(it != operand_node->uses.end()); + operand_node->uses.erase(it); + + // If the elided copy has any uses which are themselves kCopy instructions + // then patch up the copy info to reflect the that this kCopy instruction + // has a different operand (the operand of the elided copy). + for (const HloUse* copy_use : copy_value_node->uses) { + operand_node->uses.push_back(copy_use); + if (copy_use->instruction->opcode() == HloOpcode::kCopy) { + copy_map_.at(copy_use->instruction).src = operand_node; + } + } + + // Delete the copy info and the value node. + copy_map_.erase(copy_value_node->value->defining_instruction()); + delete copy_value_node; + } + + // Returns true if the live range of given value 'a' is before the live + // range of 'b'. + // + // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not + // updated as copies are removed. + bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { + if (a.uses.empty()) { + VLOG(2) << "Empty uses"; + return ordering_.IsDefinedBefore(*a.value, *b.value); + } + for (const HloUse* use : a.uses) { + VLOG(2) << "use: " << *use; + VLOG(2) << "is before:" << *b.value; + if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { + VLOG(2) << "Not before"; + return false; + } + } + return true; + } + + // Returns whether 'node' is the last node in its list. + bool IsTail(const ValueNode& node) const { + return ContainsKey(value_lists_, node.next); + } + + // Returns whether 'node' is the first node in its list. + bool IsHead(const ValueNode& node) const { + return ContainsKey(value_lists_, &node); + } + + // Returns the next node in the list after 'node'. If 'node' is the + // tail, then nullptr is returned. + ValueNode* Next(const ValueNode& node) const { + if (IsTail(node)) { + return nullptr; + } else { + return node.next; + } + } + + // Returns the previous node in the list before 'node'. If 'node' + // is the head, then nullptr is returned. + ValueNode* Prev(const ValueNode& node) const { + if (IsHead(node)) { + return nullptr; + } else { + return node.prev; + } + } + + // Splices the entire linked list with 'head' as its head right after the + // node 'insert_after' in another linked list. + void SpliceAfter(ValueNode* head, ValueNode* insert_after) { + DCHECK(IsHead(*head)); + value_lists_.erase(head); + + ValueNode* tail = head->prev; + tail->next = insert_after->next; + insert_after->next->prev = tail; + + insert_after->next = head; + head->prev = insert_after; + } + + string ValueListToString(const ValueNode* element) { + const ValueNode* head = element; + while (!IsHead(*head)) { + head = Prev(*head); + } + std::vector values; + for (const ValueNode* p = head; p != nullptr; p = Next(*p)) { + values.push_back(p->value); + } + return StrCat("{", + Join(values, ", ", + [](string* s, const HloValue* value) { + StrAppend(s, value->ToShortString()); + }), + "}"); + } + + string ToString() const { + string out = StrCat("BufferValueTracker:\n"); + StrAppend(&out, " Def-use chains in each buffer:\n"); + for (const ValueNode* head : value_lists_) { + StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), + ":\n"); + const ValueNode* p = head; + do { + StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", + Join(p->uses, "; ", + [](string* s, const HloUse* use) { + StrAppend(s, use->ToString()); + }), + "\n"); + + p = p->next; + } while (p != head); + } + StrAppend(&out, " Potentially removable copies:\n"); + for (const auto& pair : copy_map_) { + const HloInstruction* copy = pair.first; + const CopyNodes& copy_info = pair.second; + + StrAppend(&out, " ", copy->name(), " : ", + copy_info.src->value->ToShortString(), " => ", + copy_info.dest->value->ToShortString(), "\n"); + } + return out; + } + + private: + const HloDataflowAnalysis& dataflow_; + const HloOrdering& ordering_; + + // The heads of all the value lists. Each value list represents the HLO + // values contained in a particular HLO buffer. The values in the list are + // in dependency order. + tensorflow::gtl::FlatSet value_lists_; + + // Copy removal requires fast access to the value list elements + // corresponding to the source and destination values of the kCopy + // instruction. This data structure holds pointers to these elements for + // each kCopy instruction in the graph. + struct CopyNodes { + // The source and destinations values of the kCopy instruction. + ValueNode* src = nullptr; + ValueNode* dest = nullptr; + }; + tensorflow::gtl::FlatMap copy_map_; + }; + + HloModule* module_; + const HloAliasAnalysis& alias_analysis_; + const HloOrdering& ordering_; + + // Object tracking the HLO values contained in each HLO buffer. + BufferValueTracker buffer_value_tracker_; +}; + +// Try to remove as many copies from the module as possible without introducing +// live range interference. Copy instructions (identified by their unique id) in +// the set copies_to_exclude are not considered for removal. +Status RemoveUnnecessaryCopies( + const HloOrdering& ordering, + const tensorflow::gtl::FlatSet& copies_to_exclude, HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + CopyRemover copy_remover(*alias_analysis, ordering, module); + XLA_VLOG_LINES(3, copy_remover.ToString()); + + tensorflow::gtl::FlatSet existing_copies; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy && + !ContainsKey(copies_to_exclude, instruction->unique_id())) { + TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + } + } } - return copy; + + return Status::OK(); } -// The 'read_only_indices' are initialized based on points-to analysis on the -// while body corresponding to 'while_hlo'. If the init buffer corresponding to -// a read-only index aliases with a constant, it cannot be considered read-only, -// and must be copied. This is necessary because BufferAssignment does not -// currently assign an allocation for constants (b/32248867). -// This function performs this fix-up of 'read_only_indices'. +// Add copies to address special constraints on the roots of computations not +// related to live range interference: // -// Returns a ShapeTree of copy_overrides, which implements an optimization to -// allow multiple while loops that share the same read-only constants to -// share a single copy. -StatusOr> RevertReadOnlyIndicesForConstants( - const HloInstruction* while_hlo, - const TuplePointsToAnalysis& points_to_analysis, - ShapeTree* read_only_indices, - FlatMap* shared_copies) { - const HloInstruction* init_hlo = while_hlo->operand(0); - const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); - - // Mapping from LogicalBuffer to index (used to detect non-distinct indices). - FlatSet buffer_set; - - ShapeTree copy_overrides(init_hlo->shape()); - points_to.ForEachElement([init_hlo, read_only_indices, shared_copies, - &buffer_set, ©_overrides]( - const ShapeIndex& index, - const PointsToSet::BufferList& buffers) { - // Look for read-only entry parameters. - if (!read_only_indices->element(index)) { - return; - } - for (const LogicalBuffer* buffer : buffers) { - HloInstruction* pointee = buffer->instruction(); - const bool is_constant = pointee->opcode() == HloOpcode::kConstant; - if (!is_constant) { - continue; - } +// (1) Entry computation root must be unambiguous and distinct. +// +// (2) Any computation called by a kCall instruction must have an +// unambiguous root. +// +// (3) Constants and parameters cannot be live out of the entry computation +// +Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + + // Identify which shape indices of which instructions need to be copied. Store + // these results in 'instructions_to_copy'. + std::unordered_map> instructions_to_copy; + auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction, + const ShapeIndex& index) { + auto it = instructions_to_copy.find(instruction); + if (it == instructions_to_copy.end()) { + auto it_added = instructions_to_copy.emplace( + std::piecewise_construct, std::forward_as_tuple(instruction), + std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); + it = it_added.first; + } + *it->second.mutable_element(index) = true; + }; - // We have found an constant that is read-only in - // the while body. These buffers are managed by the caller, and cannot - // be aliased with HLO buffers. Revert this read-only index, - // to allow it to be copied. - *read_only_indices->mutable_element(index) = false; - - // Optimization to allow multiple while loops that share the same - // read-only entry constants to share a single copy. - // Only unambiguous and distinct array-shaped buffers are allowed, to - // reduce code complexity. The shape of the entry parameter must be - // identical to the shape of the init_hlo at this index, to ensure - // there were no intervening bitcast or GTE instructions, which are - // also hard to handle. - const Shape& pointee_shape = pointee->shape(); - const Shape& init_shape = - ShapeUtil::GetSubshape(init_hlo->shape(), index); - if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && - ShapeUtil::Equal(pointee_shape, init_shape) && - buffer_set.count(buffer) < 1) { - HloInstruction** copy = &(*shared_copies)[pointee]; - if (*copy == nullptr) { - *copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary( - pointee_shape, HloOpcode::kCopy, pointee)); + // Iterate through values of all constants and entry parameters. These values + // are special because they are held in read-only buffers. If any of these + // values share a buffer with other values (for example, the init value of a + // while is a constant) then copy the value at its definition and replace all + // its uses with the copy. + for (const HloValue* value : alias_analysis->dataflow_analysis().values()) { + if (ValueIsReadOnly(*value) && + alias_analysis->GetBufferContainingValue(*value).values().size() > 1) { + VLOG(2) << "Value " << value->ToShortString() + << " is read only, but its buffer contains more than one value. " + "Copying."; + add_index_to_copy(value->defining_instruction(), value->defining_index()); + } + } + + // Identify copies which must be added at root instructions + for (HloComputation* computation : module->computations()) { + const CallGraphNode& node = call_graph.GetNode(computation); + if (node.context() == CallContext::kParallel) { + continue; + } + TF_RET_CHECK(node.context() == CallContext::kSequential); + + const bool is_entry = computation == module->entry_computation(); + HloInstruction* root = computation->root_instruction(); + + // Mark nondistinct/ambiguous indices. + tensorflow::gtl::FlatSet seen; + ShapeUtil::ForEachSubshape( + root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector buffers_at_index = + alias_analysis->ComputeBuffersAt(root, index); + bool buffer_seen_before = false; + for (const HloBuffer* buffer : buffers_at_index) { + buffer_seen_before |= !seen.insert(buffer).second; + } + if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { + VLOG(2) << "Index " << index << " of root of computation " + << computation->name() << " (" << root->name() + << ") has ambiguous or non-distinct buffer. Copying."; + add_index_to_copy(root, index); + } + }); + + // For entry instructions, mark any parameter or constant values. + if (is_entry) { + for (const auto& pair : + alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (ValueIsReadOnly(*value)) { + VLOG(2) << "Root of entry computation (" << root->name() + << ") has constant or entry parameter value at index " + << index << ". Copying."; + add_index_to_copy(root, index); + } } - // Add the copy as an override. - *copy_overrides.mutable_element(index) = *copy; } + } + } - // Tracks whether this current buffer is distinct. - buffer_set.insert(buffer); + // Add copy instructions indicated in 'instructions_to_copy' to the module. + for (const auto& pair : instructions_to_copy) { + HloInstruction* instruction = pair.first; + const ShapeTree& indices_to_copy = pair.second; - // We've already reverted the read-only index and handled the - // single-copy optimization above, so there's nothing more to do. - break; + std::vector users = instruction->users(); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + instruction->parent()->DeepCopyInstruction( + instruction, &indices_to_copy)); + for (HloInstruction* user : users) { + TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); + } + if (instruction == instruction->parent()->root_instruction()) { + instruction->parent()->set_root_instruction(deep_copy); } - }); - return copy_overrides; + } + + return Status::OK(); +} + +Status VerifyNoLiveRangeInterference(HloModule* module) { + TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, + HloAliasAnalysis::Run(module)); + DependencyHloOrdering ordering(module); + TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); + return Status::OK(); } -} // anonymous namespace - -// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the -// base class, since the regular CopyInsertion logic above selectively copies -// tuple elements, while this method assumes all buffers need to be deep copied. -StatusOr CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) { - auto copy_it = inserted_copies_.find(hlo); - if (copy_it == inserted_copies_.end()) { - HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie(); - inserted_copies_.insert({hlo, copy}); - return copy; - } else { - return copy_it->second; +void MaybeDumpModule(const string& message, const HloModule& module) { + if (VLOG_IS_ON(3)) { + VLOG(3) << message; + XLA_VLOG_LINES(3, module.ToString()); + hlo_graph_dumper::MaybeDumpHloModule(module, message); } } +} // namespace + StatusOr CopyInsertion::Run(HloModule* module) { - bool changed = false; - VLOG(2) << "CopyInsertion for module " << module->name(); + // Copy insertion is performed in three steps: + // + // (1) Add copies conservatively to guarantee that there is no live-range + // interference. This is done simplistically and usually results in more + // copies than is strictly necessary. + // + // (2) Using a more fine-grained analysis, remove as many copies that were + // added in (1) as possible while ensuring no live-range interference. + // + // (3) Add copies to resolve issues not related to live range interference + // such as parameters and constants live out of the entry computation. + // + // We add copies then remove them (step (1) then (2)) rather than simply + // adding only the copies that are necessary because, in general, it is + // difficult to figure out the minimal set of copies to add once there is + // interference. On the other hand, it is easy to determine if removing a copy + // will introduce interference. + // + // The final copy insertion in (3) is done separately to simplify the + // implementation of copy removal in (2) which is the most complicated part of + // the pass. As is, copy removal only has to reason about live range + // interference. If all copies were added in step (1) then copy removal would + // also have to reason about things like constants and parameters live out of + // the computation. + MaybeDumpModule("before copy insertion", *module); - TF_ASSIGN_OR_RETURN( - std::unique_ptr liveness, - BufferLiveness::Run(module, MakeUnique(module))); - const auto& points_to_analysis = liveness->points_to_analysis(); - XLA_VLOG_LINES(2, points_to_analysis.ToString()); - XLA_VLOG_LINES(2, module->ToString()); - - // Gather all while body computations and while instructions. - FlatSet while_body_computations; - std::vector while_instructions; - for (auto* computation : module->computations()) { + std::unique_ptr call_graph = CallGraph::Build(module); + if (!call_graph->IsFlattened()) { + return FailedPrecondition( + "Call graph must be flattened before copy insertion."); + } + + // Gather Ids of existing kCopy instructions in the module. We avoid removing + // these copies (except via DCE in TupleSimplifier) because they may have been + // added for reasons not considered by copy insertion (eg, layout assignment). + // Instruction id is used instead of HloInstruction* because the pointer + // values may be recycled. + tensorflow::gtl::FlatSet existing_copies; + for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kWhile) { - while_body_computations.insert(instruction->while_body()); - while_instructions.push_back(instruction); + if (instruction->opcode() == HloOpcode::kCopy) { + existing_copies.insert(instruction->unique_id()); } } } - // Collect instruction buffer indices to copy in 'instructions_to_copy'. - std::vector instructions_to_copy; - - // Add copies of computation root instructions, if needed. - FlatMap> while_body_read_only_indices; - for (auto* computation : module->MakeNonfusionComputations()) { - VLOG(2) << "computation " << computation->name(); - InstructionCopier root_copier(computation->root_instruction(), - /*copy_users=*/{}); - if (while_body_computations.count(computation) > 0) { - // Record root indices to copy for while body sub-computations. We do not - // need to call RecordIndicesWhichPointToParamOrConstant for the while - // body root instruction here, because any necessary copies needed to - // avoid constants or parameters in the output are handled by while.init - // operand copy insertion below (which will share an allocation). - HloInstruction* while_body_param = computation->parameter_instruction(0); - ShapeTree read_only_indices(while_body_param->shape()); - TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( - *liveness, while_body_param, &read_only_indices)); - while_body_read_only_indices[computation] = read_only_indices; - - // Mark control predecessors, based on the body param, for any copies - // we'll be inserting. This ensures the copy doesn't run too early. - TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors( - points_to_analysis, while_body_param)); - } else { - // Record root indices to copy for general computations. - TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant( - points_to_analysis)); + TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module)); + + // Simplify the tuple structures introduced by the deep copies. This should be + // done before removing copies (RemoveUnnecessaryCopies) because tuple + // simplification changes dependencies in the graph which changes live range + // interference in the graph. Also run DCE to remove the dead Tuple/GTE + // instructions introduced by tuple simplification. + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + + TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + + MaybeDumpModule("after adding copies to resolve interference", *module); + + DependencyHloOrdering ordering(module); + TF_RETURN_IF_ERROR( + RemoveUnnecessaryCopies(ordering, existing_copies, module)); + + MaybeDumpModule("after removing unnecessary copies", *module); + + TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); + + MaybeDumpModule("after adding special-case copies", *module); + + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + TF_RETURN_IF_ERROR(dce.Run(module).status()); + TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + + MaybeDumpModule("after copy insertion", *module); + + if (VLOG_IS_ON(1)) { + int64 num_total_copies = 0; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + num_total_copies++; + } + } } - instructions_to_copy.push_back(root_copier); + VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size(); + VLOG(1) << "Num copies after copy-insertion: " << num_total_copies; } - // Add copies of while 'init' operand instructions, if needed. 'shared_copies' - // is used to ensure that multiple while loops can share a single copy of the - // same entry parameter or constant, if all loops use it read-only. - // - // TODO(b/33301720) Remove redundant while instruction copies. - FlatMap shared_copies; - for (HloInstruction* while_hlo : while_instructions) { - // Fix read_only_indices to account for entry constants. Also - // initialize copy_overrides, which ensures a single copy for each read-only - // constant that is used in multiple while loops. - ShapeTree* read_only_indices = - &while_body_read_only_indices[while_hlo->while_body()]; - TF_ASSIGN_OR_RETURN( - const ShapeTree copy_overrides, - RevertReadOnlyIndicesForConstants(while_hlo, points_to_analysis, - read_only_indices, &shared_copies)); - // Create InstructionCopier for init operand of while instruction. - HloInstruction* init_hlo = while_hlo->mutable_operand(0); - InstructionCopier init_copier(init_hlo, {while_hlo}); - init_copier.SetReadOnlyIndices(*read_only_indices); - init_copier.SetCopyOverrides(copy_overrides); - // Record 'init' buffer indices which point-to a Constant or Parameter. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( - points_to_analysis)); - // Record indices necessary to colocate while and init operand buffers. - TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers( - *liveness, while_hlo, /*read_only_indices_out=*/nullptr)); - instructions_to_copy.push_back(init_copier); + return true; +} + +namespace { + +bool IsWhileBody(const HloComputation* computation, + const CallGraph& call_graph) { + const CallGraphNode& node = call_graph.GetNode(computation); + + if (node.context() == CallContext::kSequential && + !node.caller_callsites().empty()) { + // Callgraph should be flattened so sequential context computations can + // have at most one caller. + CHECK_EQ(node.caller_callsites().size(), 1); + const HloInstruction* calling_instruction = + node.caller_callsites()[0].instruction(); + if (calling_instruction->opcode() == HloOpcode::kWhile && + calling_instruction->while_body() == node.computation()) { + return true; + } } + return false; +} - for (InstructionCopier& to_copy : instructions_to_copy) { - if (to_copy.HasAllIndicesFalse()) { +} // namespace + +/* static */ StatusOr CopyInsertion::AddCopiesForBufferAssignment( + HloModule* module) { + std::unique_ptr call_graph = CallGraph::Build(module); + TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, + HloDataflowAnalysis::Run(module)); + + bool changed = false; + + // If a buffer live out of a computation is a constant, a parameter, or not + // defined in the computation, then copy it to account for the limited + // computation-scoped analysis in buffer assignment. An exception to this rule + // is the while body which is handled properly without copies. + for (HloComputation* computation : module->computations()) { + if (computation == module->entry_computation() || + IsWhileBody(computation, *call_graph)) { continue; } - changed = true; - // Copy instruction at recorded buffer indices. - HloComputation* computation = to_copy.instruction()->parent(); - HloInstruction* copy = to_copy.Copy(); - if (to_copy.instruction() == computation->root_instruction()) { - computation->set_root_instruction(copy); + HloInstruction* root = computation->root_instruction(); + ShapeTree indices_to_copy(root->shape(), /*init_value=*/false); + bool copy_root = false; + for (const auto& pair : dataflow->GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + HloInstruction* def = value->defining_instruction(); + if (def->parent() != computation || + def->opcode() == HloOpcode::kConstant || + def->opcode() == HloOpcode::kParameter) { + *indices_to_copy.mutable_element(index) = true; + copy_root = true; + } + } + } + if (copy_root) { + TF_ASSIGN_OR_RETURN( + HloInstruction * root_copy, + computation->DeepCopyInstruction(root, &indices_to_copy)); + computation->set_root_instruction(root_copy); + changed = true; } } - VLOG(3) << "After copy insertion for module " << module->name(); - XLA_VLOG_LINES(3, module->ToString()); + TupleSimplifier tuple_simplifier; + HloDCE dce; + TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed, + tuple_simplifier.Run(module)); + TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module)); - return changed; + return changed || tuple_simplifier_changed || dce_changed; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h index 28bb62e40c7674960dbb1bb63dc8967b06956028..65e3d31e347e2cb249a072e7d06ca10c55401748 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.h +++ b/tensorflow/compiler/xla/service/copy_insertion.h @@ -25,12 +25,25 @@ limitations under the License. namespace xla { -// HLO pass which inserts a copy of the root instruction (creating a new root) -// if the root is or points-to any constant or parameter instruction. -// If the root instruction is a Tuple, only tuple elements which point to -// constant or parameter instructions will be copied. -// Copy insertion is necessary because constant and parameter arrays have -// different lifetimes than computation results. +// Copy insertion is a legalization HLO pass which inserts copies (kCopy +// instructions) to eliminate several kinds of problems in the HLO module. +// +// (1) Entry parameter or a constant live out of the entry computation. Entry +// computation arguments and constants have different lifetimes than the +// computation result and cannot share the same allocation. Parameters and +// constants live out of non-entry computations do not need copies. +// +// (2) Different values which are simultaneously live and which must be held +// in the same buffer. This can occur in while bodies. Specifically, the +// while loop state (the arguments to the while instruction) is updated +// in-place and the update may clobber the value from the previous +// iteration before the previous value is dead. Computations called from +// kCall instructions do not need such copies because kCall has no update +// in-place semantics. +// +// (3) The buffer set of the root instruction of the entry computation must be +// unambiguous and distinct. That is, InstructionAliasSet::IsAmbiguous and +// InstructionAliasSet::IsDistinct return true. class CopyInsertion : public HloPassInterface { public: tensorflow::StringPiece name() const override { return "copy-insertion"; } @@ -39,14 +52,16 @@ class CopyInsertion : public HloPassInterface { // (copies were inserted). StatusOr Run(HloModule* module) override; - protected: - // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making - // duplicate copies. - StatusOr FindOrInsertCopy(HloInstruction* hlo); - - // A map containing all copies inserted during the copy insertion pass. The - // key is the copied instruction and the value is the copy. - tensorflow::gtl::FlatMap inserted_copies_; + // The CPU and GPU backend need additional copies added due to deficiencies in + // buffer assignment. Specifically, copies are needed for constants live-out + // of computations, and for values which are live-in and live-out of the same + // computation. These copies are needed because buffer-assignment uses a + // computation-scoped analyis (TuplePointsToAnalysis) and has limited + // visibility across computation boundaries. This method adds these necessary + // copies. Returns whether the module was modified. + // + // TODO(b/62548313): Remove this when buffer assignment is module-scoped. + static StatusOr AddCopiesForBufferAssignment(HloModule* module); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index a2eacc5c7dae2424e01fdd49d82546b5488d4312..8388574716ad1b78eb8868a8cd732005050b3310 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -17,18 +17,19 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace op = xla::testing::opcode_matchers; @@ -37,35 +38,53 @@ namespace { using ::testing::UnorderedElementsAre; +int64 CountCopies(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + count++; + } + } + return count; +} + +int64 CountCopies(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountCopies(*computation); + } + return count; +} + +int64 CountControlEdges(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + count += instruction->control_successors().size(); + } + return count; +} + +int64 CountControlEdges(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountControlEdges(*computation); + } + return count; +} + class CopyInsertionTest : public HloTestBase { protected: void InsertCopies(HloModule* module) { CopyInsertion copy_insertion; - EXPECT_IS_OK(copy_insertion.Run(module).status()); - - // Verify the points to set of the root of the computation after copy - // insertion contains no constants or parameters, and is distinct and - // non-ambiguous. - auto points_to_analysis = - TuplePointsToAnalysis::Run(module).ConsumeValueOrDie(); - const auto& points_to = points_to_analysis->GetPointsToSet( - module->entry_computation()->root_instruction()); - EXPECT_TRUE(points_to.IsDistinct()); - EXPECT_TRUE(!points_to.IsAmbiguous()); - - auto maybe_live_out_buffers = - points_to_analysis - ->GetPointsToSet(module->entry_computation()->root_instruction()) - .CreateFlattenedSet(); - - for (const LogicalBuffer* buffer : maybe_live_out_buffers) { - EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant); - EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter); - } + ASSERT_IS_OK(copy_insertion.Run(module).status()); } + + const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); }; TEST_F(CopyInsertionTest, SingleParameter) { + // Computation is a single parameter passed into a tuple. The parameter should + // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -77,14 +96,15 @@ TEST_F(CopyInsertionTest, SingleParameter) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(x))); } TEST_F(CopyInsertionTest, SingleConstant) { + // Computation is a single constant passed into a tuple. The parameter should + // be copied before entering the tuple. auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -96,11 +116,42 @@ TEST_F(CopyInsertionTest, SingleConstant) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(constant))); +} + +TEST_F(CopyInsertionTest, ExistingCopiesNotRemoved) { + // Verify that an kCopy instructions which exist in the pass before + // copy-insertion remain in the graph after copy-insertion. + auto module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + HloInstruction* copy_1 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* copy_2 = builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kCopy, constant)); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + constant->shape(), HloOpcode::kAdd, copy_1, copy_2)); + HloInstruction* add_copy = builder.AddInstruction( + HloInstruction::CreateUnary(constant->shape(), HloOpcode::kCopy, add)); + + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(CountCopies(*module), 3); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 3); + + EXPECT_EQ(module->entry_computation()->root_instruction(), add_copy); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Copy(op::Add(op::Copy(op::Constant()), op::Copy(op::Constant())))); } TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { @@ -127,12 +178,12 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)), - op::Copy(old_root->operand(1)), old_root->operand(2))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(constant2), op::Copy(x), op::Add(constant1, y))); } TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { @@ -165,6 +216,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Tuple(op::Copy(op::GetTupleElement(old_root)), @@ -187,6 +239,7 @@ TEST_F(CopyInsertionTest, BitcastParameter) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -208,6 +261,7 @@ TEST_F(CopyInsertionTest, BitcastConstant) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -227,11 +281,11 @@ TEST_F(CopyInsertionTest, BitcastTupleElementParameter) { EXPECT_THAT(x->users(), UnorderedElementsAre(bitcast)); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(old_root->operand(0)))); + op::Tuple(op::Copy(bitcast))); } TEST_F(CopyInsertionTest, NestedTupleParameter) { @@ -257,6 +311,8 @@ TEST_F(CopyInsertionTest, NestedTupleParameter) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 3); + HloInstruction* new_root = module->entry_computation()->root_instruction(); EXPECT_NE(old_root, new_root); @@ -283,7 +339,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - // The return value of the computation is the zero-th elemnt of the nested + // The return value of the computation is the zero-th element of the nested // tuple. This element is itself a tuple. auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); @@ -293,12 +349,13 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { EXPECT_EQ(gte, module->entry_computation()->root_instruction()); - HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Tuple(op::Copy(op::GetTupleElement(old_root)), - op::Copy(op::GetTupleElement(old_root)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(op::GetTupleElement(op::GetTupleElement(param))), + op::Copy(op::GetTupleElement(op::GetTupleElement(param))))); } TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { @@ -331,6 +388,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction* old_root = module->entry_computation()->root_instruction(); InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); EXPECT_THAT(module->entry_computation()->root_instruction(), op::Copy(old_root)); @@ -346,12 +404,10 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // The parameter 'nested' specifies the loop state shape from which to // read the induction variable. std::unique_ptr BuildConditionComputation( - bool nested = false) { + const Shape& loop_state_shape) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(10))); - const Shape& loop_state_shape = - nested ? nested_loop_state_shape_ : loop_state_shape_; auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); auto induction_variable = @@ -582,7 +638,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto loop_state_init = builder.AddInstruction( HloInstruction::CreateTuple({induction_var_init, inner_init})); auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( - loop_state_shape_, condition, body, loop_state_init)); + loop_state_init->shape(), condition, body, loop_state_init)); module_->AddEntryComputation(builder.Build()); return while_hlo; } @@ -658,11 +714,28 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. - builder.AddInstruction(HloInstruction::CreateBinary( + auto add = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); - return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, - &builder); + auto xla_while = BuildWhileInstructionWithCustomInit(loop_state_shape_, + data_init, &builder); + + // Add an additional binary operation operating on the while and the + // interfering add so that neither operation is dead. + auto gte = xla_while->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(xla_while->shape(), {1}), xla_while, 1)); + auto sub = xla_while->parent()->AddInstruction(HloInstruction::CreateBinary( + data_shape_, HloOpcode::kSubtract, add, gte)); + auto gte0 = xla_while->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(xla_while->shape(), {0}), xla_while, 0)); + auto tuple = xla_while->parent()->AddInstruction( + HloInstruction::CreateTuple({gte0, sub})); + + xla_while->parent()->set_root_instruction(tuple); + + return xla_while; } HloInstruction* BuildWhileInstructionWithCustomInit( @@ -672,8 +745,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0))); - auto condition = - module_->AddEmbeddedComputation(BuildConditionComputation(nested)); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); auto body = module_->AddEmbeddedComputation( BuildIndependentBodyComputation(nested)); auto loop_state_init = builder->AddInstruction( @@ -706,23 +779,21 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // CopyInsertion pass should not generate any copies. // TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildIndependentBodyComputation()); auto while_hlo = BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - // No copies should be inserted so root should not be updated. - EXPECT_EQ(old_root, new_root); + // Body should have no copies as the adds can be done inplace. + EXPECT_EQ(CountCopies(*body), 0); + EXPECT_EQ(CountControlEdges(*module_), 0); - // Both init indices need copies. - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + // Both init indices need copies as they are constants. + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while body computation with dependent tuple elements: @@ -737,20 +808,33 @@ TEST_F(WhileCopyInsertionTest, IndependentTupleElements) { // Tuple(Copy(out0), out1) // TEST_F(WhileCopyInsertionTest, DependentTupleElements) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildDependentBodyComputation()); auto while_hlo = BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - EXPECT_THAT(new_root, - op::Tuple(op::Copy(old_root->operand(0)), old_root->operand(1))); - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_EQ(CountCopies(*body), 1); + EXPECT_EQ(CountControlEdges(*body), 0); + + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Add(op::GetTupleElement(), op::Broadcast()))); + + auto add = body->root_instruction()->operand(0); + auto bcast = body->root_instruction()->operand(1)->operand(1); + ASSERT_EQ(add->opcode(), HloOpcode::kAdd); + ASSERT_EQ(bcast->opcode(), HloOpcode::kBroadcast); + + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(op::Copy(), op::Constant()), + op::Add(op::GetTupleElement(), op::Broadcast(op::Copy())))); + + // Both init indices need copies as they are constants. + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while body computation with read-only tuple element 0: @@ -768,33 +852,26 @@ TEST_F(WhileCopyInsertionTest, DependentTupleElements) { // // CopyInsertion pass should not generate any copies for the while body. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) { - auto condition = module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); - auto while_hlo = BuildWhileInstruction(condition, body); + BuildWhileInstruction(condition, body); - const HloInstruction* old_init = while_hlo->operand(0); - HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - HloInstruction* new_root = body->root_instruction(); - const HloInstruction* new_init = while_hlo->operand(0); - - // No copies should be inserted in the body, so root should not be updated. - EXPECT_EQ(old_root, new_root); - // Both indices need copies, even though Index 0 is read-only, since both are - // constants, which must be copied. - EXPECT_THAT(new_init, op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + // No copies or control edges should be inserted. The body is legal as is. + EXPECT_EQ(CountCopies(*body), 0); + EXPECT_EQ(CountControlEdges(*body), 0); } // Same as above, but with two while loops, sharing entry parameters. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly_TwoLoops_EntryParams) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body1 = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); auto body2 = module_->AddEmbeddedComputation( @@ -812,30 +889,46 @@ TEST_F(WhileCopyInsertionTest, loop_state_shape_, condition1, body1, loop_init)); auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + + // Add a couple elements from each of the while so both whiles are live. + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); + + auto entry = module_->AddEntryComputation(builder.Build()); InsertCopies(module_.get()); - // Both while loops alias iter_param, since index 0 is read-only in the body. - EXPECT_EQ(while_hlo1->operand(0)->operand(0), - while_hlo2->operand(0)->operand(0)); - EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_param); + // Neither body should have any copies or control edges in them. + EXPECT_EQ(CountCopies(*body1), 0); + EXPECT_EQ(CountCopies(*body2), 0); + EXPECT_EQ(CountControlEdges(*body1), 0); + EXPECT_EQ(CountControlEdges(*body2), 0); - // Each while loop gets its own copy of data_param, since index 1 is not - // read-only in the body. + // Only two copies should be necessary. Each of the whiles should have + // a copy of tuple element 1 (init value is a parameter, and the element is + // not non-read-only) so each of the while bodies gets its own buffer to write + // element 1 into. + EXPECT_EQ(CountCopies(*entry), 2); + + EXPECT_EQ(while_hlo1->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); + EXPECT_EQ(while_hlo2->operand(0)->operand(1)->opcode(), HloOpcode::kCopy); + + // The two copies of element 1 should be different. EXPECT_NE(while_hlo1->operand(0)->operand(1), while_hlo2->operand(0)->operand(1)); - EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_param)); - EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_param)); } // Same as above, but with two while loops, sharing non-parameters. TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly_TwoLoops_NonParams) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape_)); auto body1 = module_->AddEmbeddedComputation( BuildDependentBodyOneReadOnlyComputation()); auto body2 = module_->AddEmbeddedComputation( @@ -858,21 +951,28 @@ TEST_F(WhileCopyInsertionTest, loop_state_shape_, condition1, body1, loop_init)); auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape_, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + + // Add a couple elements from each of the while so both whiles are not dead. + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo2->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); + auto entry = module_->AddEntryComputation(builder.Build()); InsertCopies(module_.get()); - // No copies of iter_value are necessary, since index 0 is read-only in both - // while bodies. - EXPECT_EQ(while_hlo1->operand(0)->operand(0), iter_value); - EXPECT_EQ(while_hlo2->operand(0)->operand(0), iter_value); + // Ideally only one copy should be necessary. One of the whiles should + // have a copy of tuple element 1 (the non-read-only element) so each of the + // while bodies gets its own buffer to write element 1 into. However, the + // analysis isn't perfect and adds an additional copy of element 0. + EXPECT_EQ(CountCopies(*entry), 2); - // Each while loop gets its own copy of data_value, since index 1 is not - // read-only in the body. - EXPECT_NE(while_hlo1->operand(0)->operand(1), - while_hlo2->operand(0)->operand(1)); - EXPECT_THAT(while_hlo1->operand(0)->operand(1), op::Copy(data_value)); - EXPECT_THAT(while_hlo2->operand(0)->operand(1), op::Copy(data_value)); + EXPECT_THAT(while_hlo1->operand(0), + op::Tuple(op::Exp(), op::Copy(op::Exp()))); + EXPECT_THAT(while_hlo2->operand(0), + op::Tuple(op::Exp(), op::Copy(op::Exp()))); } // Tests while body computation with nested tuple elements: @@ -905,18 +1005,34 @@ TEST_F(WhileCopyInsertionTest, // Tuple // new root // TEST_F(WhileCopyInsertionTest, NestedTupleElements) { - auto condition = - module_->AddEmbeddedComputation(BuildConditionComputation(true)); + auto condition = module_->AddEmbeddedComputation( + BuildConditionComputation(nested_loop_state_shape_)); auto body = module_->AddEmbeddedComputation(BuildNestedBodyComputation()); BuildWhileInstruction(condition, body, true); - HloInstruction* old_root = body->root_instruction(); + // HloInstruction* old_root = body->root_instruction(); InsertCopies(module_.get()); - EXPECT_THAT(body->root_instruction(), - op::Tuple(old_root->operand(0), - op::Tuple(old_root->operand(1)->operand(0), - op::Copy(old_root->operand(1)->operand(1))))); + // The only copy necessary is for the kReverse as it cannot be done + // in-place (instruction can share buffer with operand). The other elements of + // the loop state are kAdd instructions which can be done in-place. + EXPECT_EQ(CountCopies(*body), 1); + + // Each element of the init needs a copy as all are constants. + EXPECT_EQ(CountCopies(*module_), 4); + + // Either the kReverse itself must be copied or the operand of the kReverse + // must be copied. + if (body->root_instruction()->operand(1)->operand(1)->opcode() == + HloOpcode::kCopy) { + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Reverse())))); + } else { + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Reverse(op::Copy())))); + } } // Tests while init instruction which points-to a constant. @@ -927,11 +1043,13 @@ TEST_F(WhileCopyInsertionTest, NestedTupleElements) { // TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { auto while_hlo = BuildWhileInstruction_InitPointsToConstant(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); + EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Constant()))); } // Tests while init instruction which points-to a parameter. @@ -942,11 +1060,13 @@ TEST_F(WhileCopyInsertionTest, InitPointsToConstant) { // TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { auto while_hlo = BuildWhileInstruction_InitPointsToParameter(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); + EXPECT_EQ(CountCopies(*module_), 2); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Parameter()))); } // Tests while init instruction which has an ambiguous points-to set. @@ -975,15 +1095,34 @@ TEST_F(WhileCopyInsertionTest, InitPointsToParameter) { // TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous(); - auto old_init = while_hlo->operand(0); - InsertCopies(module_.get()); - EXPECT_THAT( - while_hlo->operand(0), - op::Tuple( - op::Copy(old_init->operand(0)), - op::Tuple(op::Copy(op::GetTupleElement(old_init->operand(1))), - op::Copy(op::GetTupleElement(old_init->operand(1)))))); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*module_), 4); + // The entry computation requires three copies to resolve the ambiguity of two + // init elements and the constant passed in as one of the init elements. + EXPECT_EQ(CountCopies(*module_->entry_computation()), 3); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Copy(op::GetTupleElement()), + op::Copy(op::GetTupleElement())))); + + // The body requires one copy because the buffer set is not distinct: the + // result of one of the adds is written into two elements of the output of the + // loop body. Either element might be copied. + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); + if (while_hlo->while_body() + ->root_instruction() + ->operand(1) + ->operand(0) + ->opcode() == HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); + } else { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); + } } // Tests while init instruction which has a non-distinct points-to set. @@ -1011,13 +1150,43 @@ TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) { // TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); - EXPECT_THAT(while_hlo->operand(0), - op::Tuple(op::Copy(old_init->operand(0)), - op::Tuple(op::Copy(old_init->operand(1)->operand(0)), - op::Copy(old_init->operand(1)->operand(0))))); + // The entry computation requires two copies to resolve the non-disinctness of + // two init elements and the constant passed in as one of the init + // elements. Either element can be copied for the distinctness issue. + EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); + if (while_hlo->operand(0)->operand(1)->operand(0)->opcode() == + HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Copy(op::Broadcast()), op::Broadcast()))); + } else { + EXPECT_THAT( + while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), + op::Tuple(op::Broadcast(), op::Copy(op::Broadcast())))); + } + + // The body requires one copy because the buffer set is not distinct: the + // result of one of the adds is written into two elements of the output of the + // loop body. Either element might be copied. + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 1); + if (while_hlo->while_body() + ->root_instruction() + ->operand(1) + ->operand(0) + ->opcode() == HloOpcode::kCopy) { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Copy(op::Add()), op::Add()))); + } else { + EXPECT_THAT( + while_hlo->while_body()->root_instruction(), + op::Tuple(op::Add(), op::Tuple(op::Add(), op::Copy(op::Add())))); + } } // Tests while init instruction buffer which interferes with while result @@ -1031,11 +1200,13 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) { // TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { auto while_hlo = BuildWhileInstruction_InitPointsToInterfering(); - auto old_init = while_hlo->operand(0); + InsertCopies(module_.get()); + EXPECT_EQ(CountCopies(*module_), 2); + EXPECT_EQ(CountCopies(*while_hlo->while_body()), 0); - EXPECT_THAT(while_hlo->operand(0), op::Tuple(op::Copy(old_init->operand(0)), - op::Copy(old_init->operand(1)))); + EXPECT_THAT(while_hlo->operand(0), + op::Tuple(op::Copy(op::Constant()), op::Copy(op::Broadcast()))); } // Tests while init instruction buffer which has a non-distinct points-to set: @@ -1044,18 +1215,21 @@ TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) { // Parameter(F32, {8}))) // // where the second and third parameters are identical *and* the tuple shared -// by another while instruction.. +// by another while instruction. // // Verifies that the resulting point-to set is distinct in the resulting Tuple // (non-identical Copys). In other words, verifies that copy sharing does not // insert identical copies to the resulting tuple. TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { - auto condition1 = - module_->AddEmbeddedComputation(BuildConditionComputation()); - auto condition2 = - module_->AddEmbeddedComputation(BuildConditionComputation()); // Loop body that outputs tuple comprises two elements dependent on the init // tuple. + const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_, data_shape_}); + + auto condition1 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); + auto condition2 = module_->AddEmbeddedComputation( + BuildConditionComputation(loop_state_shape)); auto body1 = module_->AddEmbeddedComputation(BuildDependentBodyComputation2()); auto body2 = @@ -1072,8 +1246,6 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto loop_init = builder.AddInstruction( HloInstruction::CreateTuple({iter_param, data_param, data_param})); - const Shape& loop_state_shape = ShapeUtil::MakeTupleShape( - {induction_variable_shape_, data_shape_, data_shape_}); // Two while loops shares the same loop init tuple. auto while_hlo1 = builder.AddInstruction(HloInstruction::CreateWhile( @@ -1081,43 +1253,478 @@ TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinctUsedByTwoWhileLoops) { auto while_hlo2 = builder.AddInstruction(HloInstruction::CreateWhile( loop_state_shape, condition2, body2, loop_init)); - module_->AddEntryComputation(builder.Build()); + // Add add instruction so neither while is dead. + auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo1, 0)); + auto gte2 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(while_hlo1->shape(), {0}), while_hlo2, 0)); + builder.AddInstruction( + HloInstruction::CreateBinary(gte1->shape(), HloOpcode::kAdd, gte1, gte2)); - auto points_to_analysis = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + module_->AddEntryComputation(builder.Build()); - // Asserts that the init tuples before copy insertion is non-distinct. - ASSERT_FALSE( - points_to_analysis->GetPointsToSet(while_hlo1->operand(0)).IsDistinct()); - ASSERT_FALSE( - points_to_analysis->GetPointsToSet(while_hlo2->operand(0)).IsDistinct()); + InsertCopies(module_.get()); - auto old_init1 = while_hlo1->operand(0); - auto old_init2 = while_hlo2->operand(0); + // None of the bodies should have copies or control flow edges. + EXPECT_EQ(CountCopies(*body1), 0); + EXPECT_EQ(CountCopies(*body2), 0); - InsertCopies(module_.get()); + // The loop bodies pass through elements 1 and 2 in the init tuple, so ideally + // these should not need to be copied before either while. However, copy + // insertion is not able to reason about the transparency of elements through + // while bodies in all circumstances so extra copies are added (b/xxx). + EXPECT_EQ(CountCopies(*module_->entry_computation()), 2); EXPECT_THAT(while_hlo1->operand(0), - op::Tuple(op::Copy(old_init1->operand(0)), - op::Copy(old_init1->operand(1)), - op::Copy(old_init1->operand(2)))); - + op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); EXPECT_THAT(while_hlo2->operand(0), - op::Tuple(op::Copy(old_init2->operand(0)), - op::Copy(old_init2->operand(1)), - op::Copy(old_init2->operand(2)))); - - // Verifies the init tuples after copy insertion is distinct. - points_to_analysis = - TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - const auto& points_to1 = - points_to_analysis->GetPointsToSet(while_hlo1->operand(0)); - EXPECT_TRUE(points_to1.IsDistinct()); - - const auto& points_to2 = - points_to_analysis->GetPointsToSet(while_hlo2->operand(0)); - EXPECT_TRUE(points_to2.IsDistinct()); + op::Tuple(op::Copy(), op::Parameter(), op::Parameter())); } +TEST_F(CopyInsertionTest, SwizzlingWhile) { + // Test a while instruction with a body which permutes its tuple parameter + // elements. + auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body simply interchanges the two tuple elements in the loop state. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_1, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 6); + + // The loop state elements should be copied at the parameter and at the root + // with a control edge in between (see DeepCopyAndAddControlEdges). This is + // technically one more copy than is strictly necessary, but in order to have + // only three copies the copies of different loop state elements must be + // ordered with a control edge. + EXPECT_EQ(CountCopies(*body), 4); + EXPECT_EQ(CountControlEdges(*body), 2); + + EXPECT_THAT(body->root_instruction(), + op::Tuple(op::Copy(op::Copy()), op::Copy(op::Copy()))); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) { + // Test a while instruction with a body which permutes its tuple parameter + // elements and applies one operation to one of the elements. The addition of + // the operation (instruction) on the element makes the live range of the + // respective input and output elements different than if the instruction were + // not there (as in the SwizzlingWhile test above). + auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body interchanges the two tuple elements in the loop state and negates one + // of them. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, body_element_1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({negate, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto tuple = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 6); + + // The loop state elements should be copied at the parameter and at the root + // with a control edge in between (see DeepCopyAndAddControlEdges). + EXPECT_EQ(CountCopies(*body), 4); + EXPECT_EQ(CountControlEdges(*body), 2); + + EXPECT_THAT( + body->root_instruction(), + op::Tuple(op::Copy(op::Negate(op::Copy())), op::Copy(op::Copy()))); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SwizzlingWhileSharedInput) { + // Test a while instruction with a body which permutes it's tuple parameter + // elements similar to SwizzlinWhile above. However, in this test the input to + // the while body is a single constant (both loop state elements are the same + // constant). This means no copies are necessary because both loop state + // elements are the same so interchanging them is a no-op. + auto module = CreateNewModule(); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); + + // Body simply interchanges the two tuple elements in the loop state. + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({body_element_1, body_element_0})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); + builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape, condition, body, tuple)); + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 2); + EXPECT_EQ(CountCopies(*body), 0); + + EXPECT_EQ(CountCopies(*module->entry_computation()), 2); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(), op::Copy())); +} + +TEST_F(CopyInsertionTest, SequentialWhiles) { + // Construct a computation with a series of sequential while instructions + // containing four loop state elements: + // + // element 0 is passed to each while directly from an entry parameter. + // + // element 1 is passed transparently in series through all the while bodies. + // + // element 2 is negated in each while body. (in-place possible) + // + // element 3 is reversed in each while body. (in-place not possible) + // + const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); + const Shape loop_state_shape = ShapeUtil::MakeTupleShape( + {element_shape, element_shape, element_shape, element_shape}); + + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, element_shape, "param_0")); + auto param_1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, element_shape, "param_1")); + auto param_2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, element_shape, "param_2")); + auto param_3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, element_shape, "param_3")); + + // The number of sequential kWhile instructions. + const int kNumWhiles = 3; + + HloInstruction* prev_element_1 = param_1; + HloInstruction* prev_element_2 = param_2; + HloInstruction* prev_element_3 = param_3; + + // Vector containing all of the while instructions. + std::vector whiles; + for (int i = 0; i < kNumWhiles; ++i) { + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto body_element_0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 0)); + auto body_element_1 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 1)); + auto body_element_2 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 2)); + auto body_element_3 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, body_param, 3)); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + element_shape, HloOpcode::kNegate, body_element_2)); + auto reverse = body_builder.AddInstruction( + HloInstruction::CreateReverse(element_shape, body_element_3, {0})); + body_builder.AddInstruction(HloInstruction::CreateTuple( + {body_element_0, body_element_1, negate, reverse})); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + cond_builder.AddInstruction(HloInstruction::CreateUnary( + cond_constant->shape(), HloOpcode::kNot, cond_constant)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto while_init = builder.AddInstruction(HloInstruction::CreateTuple( + {param_0, prev_element_1, prev_element_2, prev_element_3})); + + auto xla_while = builder.AddInstruction(HloInstruction::CreateWhile( + loop_state_shape, condition, body, while_init)); + whiles.push_back(xla_while); + if (i != kNumWhiles - 1) { + prev_element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 1)); + prev_element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 2)); + prev_element_3 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, xla_while, 3)); + } + } + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + // Each while body has one copy. And each loop state element is copied once in + // the entry computation. + EXPECT_EQ(CountCopies(*module), 4 + kNumWhiles); + + // Each while body should have exactly one copy for element three which is an + // op (kReverse) which cannot be done in place. + for (const HloInstruction* xla_while : whiles) { + EXPECT_EQ(CountCopies(*xla_while->while_body()), 1); + } + + EXPECT_THAT(whiles[0]->operand(0), op::Tuple(op::Parameter(), op::Parameter(), + op::Copy(), op::Copy())); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Tuple(op::Copy(), op::Copy(), op::GetTupleElement(), + op::GetTupleElement())); +} + +TEST_F(CopyInsertionTest, WhileBodyWithConstantRoot) { + // Test a while body and condition which are each simply a constant (root of + // computation is a constant). The body constant should be copied. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + + auto body_builder = HloComputation::Builder("body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 2); + + EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); + EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); + EXPECT_THAT(condition->root_instruction(), op::Constant()); +} + +std::unique_ptr MakeTrivialCondition(const Shape& shape) { + auto builder = HloComputation::Builder("trivial_condition"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "loop_state")); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + builder.AddInstruction(HloInstruction::CreateUnary( + constant->shape(), HloOpcode::kNot, constant)); + return builder.Build(); +} + +std::unique_ptr MakeBenchmarkWhileBody() { + auto builder = HloComputation::Builder("benchmark_loop_body"); + const Shape element_shape = ShapeUtil::MakeShape(F32, {42}); + const Shape loop_state_shape = + ShapeUtil::MakeTupleShape({element_shape, element_shape, element_shape}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + HloInstruction* element_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 0)); + HloInstruction* element_1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 1)); + HloInstruction* element_2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, 2)); + + HloInstruction* rev_1 = builder.AddInstruction( + HloInstruction::CreateReverse(element_shape, element_1, {0})); + HloInstruction* add_1_2 = builder.AddInstruction(HloInstruction::CreateBinary( + element_shape, HloOpcode::kAdd, element_1, element_2)); + + builder.AddInstruction( + HloInstruction::CreateTuple({element_0, rev_1, add_1_2})); + return builder.Build(); +} + +void BM_SequentialWhiles(int num_iters, int num_whiles) { + // This benchmark constructs a chain of sequential while instructions. + tensorflow::testing::StopTiming(); + for (int i = 0; i < num_iters; ++i) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), + config); + + auto builder = HloComputation::Builder("BM_SequentialWhiles"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {42}), "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {42}), "y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {42}), "z")); + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); + + HloInstruction* prev_loop_state = init; + for (int w = 0; w < num_whiles; ++w) { + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); + prev_loop_state = builder.AddInstruction(HloInstruction::CreateWhile( + init->shape(), condition, body, prev_loop_state)); + } + module.AddEntryComputation(builder.Build()); + + CopyInsertion copy_insertion; + + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + + // The entry computation should have three copies, and each body has one. + ASSERT_EQ(CountCopies(module), 3 + num_whiles); + } +} + +void BM_ParallelWhiles(int num_iters, int num_whiles) { + // This benchmark constructs a fan-out of parallel while instructions. + tensorflow::testing::StopTiming(); + for (int i = 0; i < num_iters; ++i) { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + HloModule module("BM_SequentialWhiles", VersionedComputationHandle(), + config); + + auto builder = HloComputation::Builder("BM_ParallelWhiles"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {42}), "x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {42}), "y")); + HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {42}), "z")); + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple({x, y, z})); + + HloInstruction* sum = nullptr; + for (int w = 0; w < num_whiles; ++w) { + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody()); + + HloInstruction* xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(init->shape(), condition, body, init)); + + if (sum == nullptr) { + sum = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); + } else { + HloInstruction* element_0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(x->shape(), xla_while, 0)); + sum = builder.AddInstruction(HloInstruction::CreateBinary( + x->shape(), HloOpcode::kAdd, sum, element_0)); + } + } + module.AddEntryComputation(builder.Build()); + + CopyInsertion copy_insertion; + + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + + // Each body receives of copy of two of the parameters (the corresponding + // elements in the body are modifed), and there is one copy in each body. + ASSERT_EQ(CountCopies(module), 3 * num_whiles); + } +} + +BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); +BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ef8eed3f88c3d557fcb4ec5b9e1988ce82b777e8..b43597dca983151d59ec7aaba9887313191fc9bd 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -17,6 +17,7 @@ package_group( load(":build_defs.bzl", "runtime_copts") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/compiler/xla:xla.bzl", "ORC_JIT_MEMORY_MAPPER_TARGETS") # Filegroup used to collect source files for dependency checking. filegroup( @@ -78,14 +79,16 @@ cc_library( deps = [ ":compiler_functor", ":conv_canonicalization", + ":cpu_copy_insertion", ":cpu_executable", ":cpu_instruction_fusion", + ":cpu_layout_assignment", ":cpu_options", ":cpu_parallelization_preparation", ":disassembler", + ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":parallel_cpu_executable", ":parallel_task_assignment", ":simple_orc_jit", @@ -101,13 +104,14 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", - "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", @@ -122,6 +126,7 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep "//tensorflow/core:lib", # fixdeps: keep "//tensorflow/core:stream_executor_no_cuda", @@ -155,21 +160,23 @@ cc_library( ":custom_call_target_registry", ":disassembler", ":external_constant_pool", + ":orc_jit_memory_mapper", ":runtime_conv2d", ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/core:lib", - "@llvm//:core", "@llvm//:execution_engine", + "@llvm//:core", "@llvm//:mc", # fixdeps: keep "@llvm//:orc_jit", "@llvm//:support", "@llvm//:target", # fixdeps: keep - ], + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ] + ORC_JIT_MEMORY_MAPPER_TARGETS, ) cc_library( @@ -245,6 +252,8 @@ cc_library( ":dot_op_emitter", ":external_constant_pool", ":ir_emission_utils", + ":ir_function", + ":parallel_loop_emitter", ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", @@ -268,19 +277,54 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:ops", "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", + "@llvm//:code_gen", "@llvm//:core", "@llvm//:support", "@llvm//:target", ], ) +cc_library( + name = "ir_function", + srcs = ["ir_function.cc"], + hdrs = ["ir_function.h"], + deps = [ + ":ir_emission_utils", + ":shape_partition", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "parallel_loop_emitter", + srcs = ["parallel_loop_emitter.cc"], + hdrs = ["parallel_loop_emitter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], hdrs = ["dot_op_emitter.h"], deps = [ + ":cpu_options", ":cpu_runtime", - ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", @@ -289,8 +333,10 @@ cc_library( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library", "//tensorflow/core:lib", "@llvm//:core", ], @@ -607,14 +653,16 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", + "@llvm//:core", ], ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "cpu_layout_assignment", + srcs = ["cpu_layout_assignment.cc"], + hdrs = ["cpu_layout_assignment.h"], deps = [ + ":dot_op_emitter", ":ir_emission_utils", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:computation_layout", @@ -624,11 +672,11 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", + name = "cpu_layout_assignment_test", size = "small", - srcs = ["layout_assignment_test.cc"], + srcs = ["cpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":cpu_layout_assignment", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -702,6 +750,7 @@ cc_library( srcs = ["parallel_task_assignment.cc"], hdrs = ["parallel_task_assignment.h"], deps = [ + ":dot_op_emitter", ":ir_emission_utils", ":shape_partition", "//tensorflow/compiler/xla/service:hlo", @@ -716,6 +765,7 @@ cc_library( hdrs = ["cpu_options.h"], deps = [ "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/core:lib", ], ) @@ -730,6 +780,48 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "orc_jit_memory_mapper", + srcs = ["orc_jit_memory_mapper.cc"], + hdrs = ["orc_jit_memory_mapper.h"], + deps = [ + "//tensorflow/core:lib", + "@llvm//:execution_engine", + ], +) + +cc_library( + name = "cpu_copy_insertion", + srcs = ["cpu_copy_insertion.cc"], + hdrs = ["cpu_copy_insertion.h"], + deps = [ + "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "cpu_copy_insertion_test", + srcs = ["cpu_copy_insertion_test.cc"], + deps = [ + ":cpu_copy_insertion", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 44cd2171afdc6eecc22f3f920276a4d95f930573..2136aeb3877685373efaf5bf702a42b39a63f082 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -41,19 +41,17 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension(); auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension(); - int num_spatial_dims = dnums.spatial_dimensions_size(); - int num_dims = num_spatial_dims + 2; + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + const int64 num_dims = num_spatial_dims + 2; // A canonical convolution's dimension numbers need to satisfy the // following conditions (see cs/PotentiallyImplementedAsEigenConvolution). // - // - the input is in NHWC or NWHC order. - // - the kernel is in HWIO or WHIO order. - // - the spatial dimensions are in the same relative order in the input, - // kernel and output. + // - the input is in NHWC order. + // - the kernel is in HWIO order. // // For simplicity, as a first step, we reshape the input and filter to - // NHWC and HWIO order, respectively. This may lose precision but not + // NHWC and HWIO order, respectively. This may lose precision but won't // break the soundness. HloInstruction* input = hlo->mutable_operand(0); @@ -61,10 +59,10 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_input_dims(num_dims); new_input_dim_order[0] = input_batch_dim; new_input_dims[0] = input->shape().dimensions(input_batch_dim); - for (int i = 0; i < num_spatial_dims; ++i) { - new_input_dim_order[i + 1] = dnums.spatial_dimensions(i); + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_input_dim_order[i + 1] = dnums.input_spatial_dimensions(i); new_input_dims[i + 1] = - input->shape().dimensions(dnums.spatial_dimensions(i)); + input->shape().dimensions(dnums.input_spatial_dimensions(i)); } new_input_dim_order[num_dims - 1] = input_feature_dim; new_input_dims[num_dims - 1] = @@ -80,7 +78,7 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_kernel_dim_order(num_dims); std::vector new_kernel_dims(num_dims); - for (int i = 0; i < num_spatial_dims; ++i) { + for (int64 i = 0; i < num_spatial_dims; ++i) { new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i); new_kernel_dims[i] = kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i)); @@ -98,14 +96,18 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { HloInstruction::CreateTranspose(new_kernel_shape, kernel, new_kernel_dim_order)); + std::vector new_output_dim_order(num_dims); std::vector new_conv_dims(num_dims); auto output_batch_dim = dnums.output_batch_dimension(); auto output_feature_dim = dnums.output_feature_dimension(); + new_output_dim_order[0] = output_batch_dim; new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim); - for (int i = 0; i < num_spatial_dims; ++i) { + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_output_dim_order[i + 1] = dnums.output_spatial_dimensions(i); new_conv_dims[i + 1] = - hlo->shape().dimensions(dnums.spatial_dimensions(i)); + hlo->shape().dimensions(dnums.output_spatial_dimensions(i)); } + new_output_dim_order[num_dims - 1] = output_feature_dim; new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim); Shape new_conv_shape = ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims); @@ -113,9 +115,10 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { ConvolutionDimensionNumbers new_dnums; new_dnums.set_input_batch_dimension(0); new_dnums.set_output_batch_dimension(0); - for (int i = 0; i < num_spatial_dims; ++i) { - new_dnums.add_spatial_dimensions(i + 1); + for (int64 i = 0; i < num_spatial_dims; ++i) { + new_dnums.add_input_spatial_dimensions(i + 1); new_dnums.add_kernel_spatial_dimensions(i); + new_dnums.add_output_spatial_dimensions(i + 1); } new_dnums.set_input_feature_dimension(num_dims - 1); new_dnums.set_output_feature_dimension(num_dims - 1); @@ -129,14 +132,11 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel, hlo->window(), new_dnums)); - // kConvolution inherits the dimension mapping of its input, so we need to - // reshape the output back to the shape of the original convolution. This - // is done by apply the inverse permutation of the collapsing order of the - // input reshape. + // Reshape the output back to the shape of the original convolution. TF_RETURN_IF_ERROR(module->entry_computation()->ReplaceWithNewInstruction( hlo, HloInstruction::CreateTranspose( hlo->shape(), new_conv, - InversePermutation(new_input_dim_order)))); + InversePermutation(new_output_dim_order)))); changed = true; } } diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index d593ba26b655d00a0f0f0b9a94c9e62fa1835080..968f53d5c706651d2a470a853e0e9b601c0ed2df 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -69,8 +69,10 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(1); dnums.set_output_batch_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_input_feature_dimension(0); dnums.set_output_feature_dimension(0); dnums.add_kernel_spatial_dimensions(2); @@ -125,8 +127,10 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index e141066b8fb48896e9f88e0a98f74aad08b63799..55e7c7bc2ca05991ac6dd53bf48bc9fd30f52601 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -46,27 +46,30 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -82,6 +85,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -195,28 +199,35 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: static StatusOr> - GetCandidatesForComputation(HloComputation* computation) { + GetCandidatesForComputation( + HloComputation* computation, + const std::unordered_map& + assigned_indices) { std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( - &hlo_to_profile_idx); + &hlo_to_profile_idx, assigned_indices); TF_RETURN_IF_ERROR( computation->Accept(&profile_candidates_for_computation)); return hlo_to_profile_idx; } private: - explicit CollectProfileCandidates( - std::unordered_map* hlo_to_profile_idx) - : hlo_to_profile_idx_(hlo_to_profile_idx) {} + CollectProfileCandidates( + std::unordered_map* hlo_to_profile_idx, + const std::unordered_map& assigned_indices) + : hlo_to_profile_idx_(hlo_to_profile_idx), + assigned_indices_(assigned_indices) {} Status DefaultAction(HloInstruction* hlo_instruction) override { - hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()}); + hlo_to_profile_idx_->insert( + {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)}); return Status::OK(); } Status HandleCall(HloInstruction* call) override { TF_RETURN_IF_ERROR(DefaultAction(call)); - CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call)); return Status::OK(); } @@ -230,17 +241,20 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override { TF_RETURN_IF_ERROR(DefaultAction(xla_while)); - CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR( xla_while->while_condition()->Accept(&candidates_for_condition)); - CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_); + CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_, + assigned_indices_); TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body)); return Status::OK(); } std::unordered_map* hlo_to_profile_idx_; + const std::unordered_map& assigned_indices_; }; } // namespace @@ -260,7 +274,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); - + pipeline.AddPass(); pipeline.AddPass(); { auto& pass = @@ -275,8 +289,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, - /*enable_dot_simplification=*/false); + /*enable_dot_strength_reduction=*/false); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -303,8 +318,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { pipeline.AddPass>( /*is_layout_sensitive=*/true, [](const Shape&, const Shape&) { return true; }, - /*enable_dot_simplification=*/false); + /*enable_dot_strength_reduction=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); + pipeline.AddPass(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -320,7 +336,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // binary size (and most AOT applications are single-threaded). // TODO(29630486) Support multi-threaded AOT. pipeline.AddPass(max_parallelism, - ShapeSizeBytesFunction(), module); + ShapeSizeBytesFunction()); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -329,15 +345,16 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); + pipeline.AddPass(); if (options::CpuParallelBackendRequested(module->config())) { // Re-run the outlining, in case any copies were inserted into the entry // computation. pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); + pipeline.AddPass(); } pipeline.AddPass(); - pipeline.AddPass(); return pipeline.Run(module).status(); } @@ -423,11 +440,25 @@ Status InitializeModuleHooks( } // namespace -StatusOr> CpuCompiler::Compile( - std::unique_ptr module, se::StreamExecutor* stream_exec) { +StatusOr> CpuCompiler::RunHloPasses( + std::unique_ptr module, + perftools::gputools::StreamExecutor* /*stream_exec*/) { + VLOG(2) << "Before optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + + VLOG(2) << "After optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + return std::move(module); +} + +StatusOr> CpuCompiler::RunBackend( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) { const string timer_message = "Compiling [" + module->name() + "] for CPU using JIT"; - ScopedLoggingTimer compiling_timer(timer_message, 1); + XLA_SCOPED_LOGGING_TIMER(timer_message); VLOG(1) << "Compiling: " << module->name(); TF_RET_CHECK(stream_exec != nullptr); @@ -441,11 +472,11 @@ StatusOr> CpuCompiler::Compile( &pre_optimization_ir_hook, &post_optimization_ir_hook)); // Compile must be thread-safe so create a new LLVM context for the module. - auto llvm_context = MakeUnique(); + auto llvm_context = xla::MakeUnique(); auto llvm_module = - MakeUnique("__compute_module", *llvm_context); + xla::MakeUnique("__compute_module", *llvm_context); - auto jit = MakeUnique( + auto jit = xla::MakeUnique( CompilerTargetOptions(module->config()), CodeGenOptLevel(module->config()), options::OptimizeForSizeRequested(module->config()), @@ -455,14 +486,29 @@ StatusOr> CpuCompiler::Compile( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); - HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; + std::unique_ptr hlo_profile_index_map; + std::unique_ptr hlo_profile_printer; if (module->config().hlo_profiling_enabled()) { + hlo_profile_index_map = MakeUnique(*module); + TF_ASSIGN_OR_RETURN( hlo_to_profile_idx, - CollectProfileCandidates::GetCandidatesForComputation(computation)); + CollectProfileCandidates::GetCandidatesForComputation( + computation, hlo_profile_index_map->instruction_to_profile_idx())); + + auto shape_size_bytes = [](const Shape& shape) { + // On the cpu, opaques are pointers. + if (ShapeUtil::IsOpaque(shape)) { + return static_cast(sizeof(void*)); + } + return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); + }; + + HloCostAnalysis cost_analysis(shape_size_bytes); + hlo_profile_printer = + CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis); } std::unique_ptr cpu_executable; @@ -485,9 +531,9 @@ StatusOr> CpuCompiler::Compile( // uses data dependencies for determining order. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(module.get(), - MakeUnique(module.get()), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run( + module.get(), xla::MakeUnique(module.get()), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -514,7 +560,7 @@ StatusOr> CpuCompiler::Compile( const void* data = instruction->literal().InternalData(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( - instruction, MakeUnique(size)); + instruction, xla::MakeUnique(size)); CHECK_EQ(iter.second, true); unsigned char* aligned_data = iter.first->second.get(); memcpy(aligned_data, data, size); @@ -528,12 +574,20 @@ StatusOr> CpuCompiler::Compile( parallel_computations.emplace(to_apply, instruction); } + // We always profile the entire computation as a whole, even if hlo + // profiling is disabled. When hlo profiling is diabled, we pass in a + // profile counter array of just one element, which corresponds to the whole + // computation. + size_t entry_computation_profile_idx = + hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor( + *module->entry_computation()) + : 0; IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine(), - jit->external_constant_pool()); + hlo_to_profile_idx, entry_computation_profile_idx, + jit->target_machine(), jit->external_constant_pool()); - std::unique_ptr> function_names( - new std::map()); + std::unique_ptr> function_names( + new HloInstructionMap()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { @@ -549,7 +603,7 @@ StatusOr> CpuCompiler::Compile( llvm::Function * ir_function, ir_emitter.EmitComputation( embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/computation_is_parallel, + /*is_top_level_computation=*/computation_is_parallel, /*instruction_order=*/nullptr)); // If this computation is parallel, remember it in the function name map. // This way we know what function to execute when we try to run code for @@ -570,8 +624,8 @@ StatusOr> CpuCompiler::Compile( jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new ParallelCpuExecutable( std::move(jit), std::move(assignment), std::move(module), - std::move(function_names), std::move(hlo_to_profile_idx), - std::move(aligned_constants))); + std::move(function_names), std::move(aligned_constants), + std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -591,10 +645,10 @@ StatusOr> CpuCompiler::Compile( // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run(module.get(), + xla::MakeUnique( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); @@ -604,13 +658,23 @@ StatusOr> CpuCompiler::Compile( TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( proto, xla_dump_hlo_proto_to, module->name())); } + // We always profile the entire computation as a whole, even if hlo + // profiling is disabled. When hlo profiling is diabled, we pass in a + // profile counter array of just one element, which corresponds to the whole + // computation. + size_t entry_computation_profile_idx = + hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor( + *module->entry_computation()) + : 0; + // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. + IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine(), - jit->external_constant_pool()); + hlo_to_profile_idx, entry_computation_profile_idx, + jit->target_machine(), jit->external_constant_pool()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -621,7 +685,7 @@ StatusOr> CpuCompiler::Compile( ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } @@ -630,7 +694,7 @@ StatusOr> CpuCompiler::Compile( TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, ir_emitter.EmitComputation(computation, function_name_prefix, - /*is_entry_computation=*/true, + /*is_top_level_computation=*/true, &module_sequence.at(computation))); string function_name = llvm_ir::AsString(entry_function->getName()); @@ -643,7 +707,7 @@ StatusOr> CpuCompiler::Compile( jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new CpuExecutable( std::move(jit), std::move(assignment), std::move(module), function_name, - std::move(hlo_to_profile_idx))); + std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -655,13 +719,6 @@ StatusOr> CpuCompiler::Compile( return std::move(cpu_executable); } -StatusOr>> CpuCompiler::Compile( - std::vector> modules, - std::vector> stream_execs) { - return Unimplemented( - "Compilation of multiple HLO modules is not yet supported on CPU."); -} - StatusOr>> CpuCompiler::CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& aot_options) { @@ -770,7 +827,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run( - module, MakeUnique(module, module_sequence), + module, + xla::MakeUnique(module, module_sequence), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. @@ -784,9 +842,13 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, proto, xla_dump_hlo_proto_to, module->name())); } - IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/nullptr, target_machine.get(), - /*external_constant_pool=*/nullptr); + IrEmitter ir_emitter( + *module, *assignment, &llvm_module, + /*hlo_to_profile_idx=*/ + std::unordered_map{}, + /*entry_computation_profile_idx=*/tensorflow::gtl::nullopt, + target_machine.get(), + /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -797,7 +859,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } @@ -805,7 +867,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, ir_emitter.EmitComputation(computation, entry_point_name, - /*is_entry_computation=*/true, + /*is_top_level_computation=*/true, &module_sequence.at(computation))); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index d09130247421b11d6d4879466f39b89167eb9564..ebed7058d8f7968c6e03ef90d0da6b2325037eb0 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -109,14 +109,20 @@ class CpuCompiler : public LLVMCompiler { CpuCompiler(); ~CpuCompiler() override {} - StatusOr> Compile( + // Bring in + // StatusOr>> Compile( + // std::vector> modules, + // std::vector> + // stream_execs) + using LLVMCompiler::Compile; + + StatusOr> RunHloPasses( std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; - StatusOr>> Compile( - std::vector> modules, - std::vector> - stream_execs) override; + StatusOr> RunBackend( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> modules, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc new file mode 100644 index 0000000000000000000000000000000000000000..baaacd2ecc9611946678f71ac36ef787ecb57b4e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr CpuCopyInsertion::Run(HloModule* module) { + CopyInsertion generic_copy_insertion; + + TF_ASSIGN_OR_RETURN(bool generic_changed, generic_copy_insertion.Run(module)); + + // The CPU backend needs additional copies added due to deficiencies in + // buffer assignment. + TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed, + CopyInsertion::AddCopiesForBufferAssignment(module)); + + return generic_changed || buffer_assignment_changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h new file mode 100644 index 0000000000000000000000000000000000000000..3313d1e6eb71bff39f509c3d24858568df786422 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Besides the modifications made by the generic xla::CopyInsertion, this +// CPU-specific copy insertion pass also adds copies to values live out of +// computations satisfying certain conditions (defined by constant or parameter, +// etc). This is necessary because of deficiencies of buffer +// assignment. Specifically, buffer assignment is computation-scoped and does +// not recognized aliasing between arguments and outputs of computations. +// +// TODO(b/62548313): Remove this when buffer assignment is smarter +// (module-scoped). +class CpuCopyInsertion : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { return "copy-insertion"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a05a26941786cbf404c4685abb098c9ac8caaa09 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_copy_insertion_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +int64 CountCopies(const HloComputation& computation) { + int64 count = 0; + for (const auto& instruction : computation.instructions()) { + if (instruction->opcode() == HloOpcode::kCopy) { + count++; + } + } + return count; +} + +int64 CountCopies(const HloModule& module) { + int64 count = 0; + for (const auto& computation : module.computations()) { + count += CountCopies(*computation); + } + return count; +} + +class CpuCopyInsertionTest : public HloTestBase { + protected: + void InsertCopies(HloModule* module) { + CpuCopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module).status()); + } + + const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); +}; + +TEST_F(CpuCopyInsertionTest, WhileBodyWithConstantRoot) { + // Test a while body and condition which are each simply a constant (root of + // computation is a constant). Each constant should be copied. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param_0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + + auto body_builder = HloComputation::Builder("body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param_0)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*module), 3); + + EXPECT_THAT(xla_while->operand(0), op::Copy(op::Parameter())); + EXPECT_THAT(body->root_instruction(), op::Copy(op::Constant())); + EXPECT_THAT(condition->root_instruction(), op::Copy(op::Constant())); +} + +TEST_F(CpuCopyInsertionTest, TupleCall) { + // Test a kCall instruction which calls a computation which produces a three + // element tuple: one is a constant, one is a parameter, and one is produced + // in the computation. The constant and parameter should be copied. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param_0")); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_, scalar_shape_}); + + auto sub_builder = HloComputation::Builder("subcomputation"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto constant = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(123.0))); + auto add = sub_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, sub_param, constant)); + sub_builder.AddInstruction( + HloInstruction::CreateTuple({sub_param, constant, add})); + HloComputation* subcomputation = + module->AddEmbeddedComputation(sub_builder.Build()); + + builder.AddInstruction( + HloInstruction::CreateCall(tuple_shape, {param}, subcomputation)); + + module->AddEntryComputation(builder.Build()); + + InsertCopies(module.get()); + + EXPECT_EQ(CountCopies(*subcomputation), 2); + EXPECT_THAT(subcomputation->root_instruction(), + op::Tuple(op::Copy(op::Parameter()), op::Copy(op::Constant()), + op::Add())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 4dba87f49906739284daea68c70ef1860127f8d0..e956f478b86d9816615e2902f5bbeae6d6384162 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/host/host_stream.h" namespace se = ::perftools::gputools; @@ -54,11 +55,12 @@ CpuExecutable::CpuExecutable( std::unique_ptr assignment, std::unique_ptr hlo_module, const string& entry_function_name, - std::unordered_map hlo_to_profile_idx) - : Executable(std::move(hlo_module)), + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), jit_(std::move(jit)), - assignment_(std::move(assignment)), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) { + assignment_(std::move(assignment)) { // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name); @@ -147,8 +149,9 @@ Status CpuExecutable::ExecuteComputeFunction( tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { std::vector argument_buffers; - for (int i = 0; i < arguments.size(); ++i) { - argument_buffers.push_back(arguments[i]->buffer(/*index=*/{})); + argument_buffers.reserve(arguments.size()); + for (const auto* argument : arguments) { + argument_buffers.push_back(argument->buffer(/*index=*/{})); } return ExecuteComputeFunction(run_options, argument_buffers, buffers, hlo_execution_profile); @@ -181,9 +184,16 @@ Status CpuExecutable::ExecuteComputeFunction( uint64 start_micros = tensorflow::Env::Default()->NowMicros(); // Allocate profiling counters for each hlo instruction that we would like to - // profile. Allocate an additional profile counter for the entire - // computation. - std::vector profile_counters(hlo_to_profile_idx_.size() + 1); + // profile. Even when not Hlo profiling, we allocate a counter for the entire + // computation, which we use to update ExecutionProfile below. + std::vector* profile_counters = nullptr; + std::vector profile_counter_for_entry_computation; + if (hlo_execution_profile) { + profile_counters = hlo_execution_profile->mutable_profile_counters(); + } else { + profile_counters = &profile_counter_for_entry_computation; + profile_counter_for_entry_computation.push_back(0); + } // Call the computation function following the calling convention. std::vector buffer_pointers; @@ -198,7 +208,7 @@ Status CpuExecutable::ExecuteComputeFunction( VLOG(3) << tensorflow::strings::Printf( " func(void* result, void* params[%zu], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters.size()); + args_array.size(), buffer_pointers.size(), profile_counters->size()); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); @@ -210,11 +220,11 @@ Status CpuExecutable::ExecuteComputeFunction( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters.data()); + profile_counters->data()); } compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters.data()); + buffer_pointers.data(), profile_counters->data()); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -223,20 +233,46 @@ Status CpuExecutable::ExecuteComputeFunction( const double nanoseconds = (end_micros - start_micros) * 1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - // The last profile counter is used for the computation as a whole. - execution_profile_.set_compute_cycle_count(profile_counters.back()); + if (hlo_execution_profile) { + execution_profile_.set_compute_cycle_count( + hlo_execution_profile->total_cycles_executed( + *module().entry_computation())); + } else { + execution_profile_.set_compute_cycle_count(profile_counters->back()); + } } - if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed( - *module().entry_computation(), profile_counters.back()); + return Status::OK(); +} + +static void LogLiveAddresses( + const std::unordered_set& marked_addresses) { + VLOG(3) << "Live addresses in output marking found " + << marked_addresses.size() << " addresses:\n" + << tensorflow::str_util::Join( + marked_addresses, ", ", [](string* out, const void* address) { + tensorflow::strings::StrAppend( + out, tensorflow::strings::Printf("%p", address)); + }); +} - for (auto hlo_prof_idx : hlo_to_profile_idx_) { - const HloInstruction* hlo = hlo_prof_idx.first; - uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->AddProfileResult(hlo, cycles_taken); +static Status DeallocateTempBuffers( + DeviceMemoryAllocator* allocator, se::Stream* stream, + tensorflow::gtl::ArraySlice buffers, + const std::unordered_set& marked_addresses) { + // Keep those marked live because they are referenced by the output of the + // computation and are needed by the service. They will be deallocated by the + // service. + for (size_t i = 0; i < buffers.size(); ++i) { + se::DeviceMemoryBase alloc = buffers[i]; + if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { + VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" + << alloc.opaque() << "]"; + TF_RETURN_IF_ERROR( + allocator->Deallocate(stream->parent()->device_ordinal(), &alloc)); } } + return Status::OK(); } @@ -262,26 +298,9 @@ StatusOr CpuExecutable::ExecuteOnStream( MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), &marked_addresses); - VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" - << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); - - // Computation is done - deallocate temp buffers. Keep those marked live - // because they are referenced by the output of the computation and are needed - // by the service. They will be deallocated by the service. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } + LogLiveAddresses(marked_addresses); + TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, + marked_addresses)); return top_level_output; } @@ -359,9 +378,44 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments) { - // TODO(b/30671675): Implement asynchronous execution mode. - return Unimplemented( - "Asynchronous execution on stream is not yet supported on CPU."); + if (hlo_profiling_enabled()) { + return Unimplemented( + "Asynchronous execution on stream with hlo profiling is not yet " + "supported on CPU."); + } + + auto* host_stream = dynamic_cast( + run_options->stream()->implementation()); + se::Stream* stream = run_options->stream(); + DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + std::vector buffers(assignment_->Allocations().size()); + + TF_RETURN_IF_ERROR(AllocateBuffers( + memory_allocator, stream->parent()->device_ordinal(), &buffers)); + + // Mark the buffers that are actually live (used in the output) when the + // computation finishes executing. + std::unordered_set marked_addresses; + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelOutputSlice()); + se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; + MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), + &marked_addresses); + + LogLiveAddresses(marked_addresses); + + host_stream->EnqueueTask([this, run_options, arguments, buffers, + marked_addresses, memory_allocator, stream]() { + // Failing a CHECK here is not great, but I don't see an obvious way to + // return a failed Status asynchronously. + TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, + buffers, + /*hlo_execution_profile=*/nullptr)); + TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, + marked_addresses)); + }); + + return top_level_output; } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { @@ -377,9 +431,5 @@ const PointsToSet& CpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr CpuExecutable::CreateCostAnalysis() const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 238bc9b46ae2bf1b519eaf137d9ae063e769bd2e..17ee2d673ee7cde1847bf29e2399e6033cb7e30e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -47,12 +47,12 @@ namespace cpu { // architecture, so JIT-ed code and host code share the same ABI. class CpuExecutable : public Executable { public: - CpuExecutable( - std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - const string& entry_function_name, - std::unordered_map hlo_to_profile_idx); + CpuExecutable(std::unique_ptr jit, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + const string& entry_function_name, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} StatusOr ExecuteOnStream( @@ -85,12 +85,10 @@ class CpuExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); - std::unique_ptr CreateCostAnalysis() const override; - // Type of the computation function we expect in the JIT. using ComputeFunctionType = void (*)( void* /*result*/, const ExecutableRunOptions* /*run_options*/, - const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/); + const void** /*args*/, void** /*temps*/, int64* /*profile_counters*/); const ComputeFunctionType& compute_function() const { return compute_function_; @@ -145,9 +143,6 @@ class CpuExecutable : public Executable { // Entry function name for the computation. const string entry_function_name_; - // Maps HLOs to their index into the profile counter array. - const std::unordered_map hlo_to_profile_idx_; - TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index b9e4d006d77ae76e33ac51440349400ea4eff118..1c04c9835e3e1ecf0f78a74aa74b0b052054004a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -31,6 +31,14 @@ namespace { using InstructionFusionTest = HloTestBase; +std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); +} + TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -40,8 +48,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, exp0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -59,8 +67,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -80,8 +88,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, bitcast0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -102,8 +110,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* reshape0 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, reshape0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -121,8 +129,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 32 * 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -140,8 +148,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -162,8 +170,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, transpose1)); + builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc similarity index 59% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index c446b6b792a042da2500ea6a175fdca4c70bcab6..0df10f4af318de3f80e4df599797709c5c43b5cd 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -13,69 +13,76 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include #include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { namespace cpu { -Status CpuLayoutAssignment::AddBackendConstraints( - LayoutConstraints* constraints) { - auto row_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - auto col_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.begin(), dimension_order.end(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - - // We want to change the layout of constant arrays to be column major when all - // of their users are dot operations that can be made faster with the flipped - // layout. To avoid going quadriatic over the # of instructions, we cache - // this property in should_make_rhs_col_major -- it maps a constant to true if - // all of the users of said constant are dot operations that can be sped up. - // This cache is populated lazily as we encounter dot operations traversing - // the instruction stream. - tensorflow::gtl::FlatMap - should_make_rhs_col_major_cache; - auto should_make_rhs_col_major = [&](const HloInstruction& instruction) { - if (ProfitableToImplementDotInLlvmIr(instruction) != - DotInLlvmIrProfitable::kWithColumnMajorRhs) { - return false; - } +// We want to change the layout of constant arrays to be column major when all +// of their users are dot operations that can be made faster with the flipped +// layout. To avoid going quadriatic over the # of instructions, we cache this +// property in should_make_rhs_col_major -- it maps a constant to true if all of +// the users of said constant are dot operations that can be sped up. This +// cache is populated lazily as we encounter dot operations traversing the +// instruction stream. + +namespace { +using ShouldMakeRhsColMajorCache = + tensorflow::gtl::FlatMap; +} - const auto* rhs = instruction.operand(1); - if (rhs->opcode() != HloOpcode::kConstant) { - return false; - } +static bool ShouldMakeRhsColMajor(ShouldMakeRhsColMajorCache* cache, + const HloInstruction& instruction) { + if (!ProfitableToMakeDotRhsColumnMajor(instruction)) { + return false; + } - auto it = should_make_rhs_col_major_cache.find(rhs); - if (it != should_make_rhs_col_major_cache.end()) { - return it->second; - } + const auto* rhs = instruction.operand(1); + if (rhs->opcode() != HloOpcode::kConstant) { + return false; + } + + auto it = cache->find(rhs); + if (it != cache->end()) { + return it->second; + } - bool result = std::all_of( - rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) { - return ProfitableToImplementDotInLlvmIr(*user) == - DotInLlvmIrProfitable::kWithColumnMajorRhs && - user->operand(0) != rhs; - }); + bool result = std::all_of(rhs->users().begin(), rhs->users().end(), + [&](HloInstruction* user) { + return ProfitableToMakeDotRhsColumnMajor(*user) && + user->operand(0) != rhs; + }); - InsertOrDie(&should_make_rhs_col_major_cache, rhs, result); - return result; - }; + InsertOrDie(cache, rhs, result); + return result; +} + +static Shape RowMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +static Shape ColMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.begin(), dimension_order.end(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +Status CpuLayoutAssignment::AddBackendConstraints( + LayoutConstraints* constraints) { + ShouldMakeRhsColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { @@ -90,9 +97,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. - Shape output_shape(row_major_shape(convolution->shape())); - Shape input_shape(row_major_shape(lhs_instruction->shape())); - Shape filter_shape(row_major_shape(rhs_instruction->shape())); + Shape output_shape(RowMajorShape(convolution->shape())); + Shape input_shape(RowMajorShape(lhs_instruction->shape())); + Shape filter_shape(RowMajorShape(rhs_instruction->shape())); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR( @@ -101,11 +108,11 @@ Status CpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(filter_shape, convolution, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, convolution)); - } else if (should_make_rhs_col_major(*instruction)) { + } else if (ShouldMakeRhsColMajor(&cache, *instruction)) { auto* dot = instruction; const auto& rhs_shape = dot->operand(1)->shape(); TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(col_major_shape(rhs_shape), dot, 1)); + constraints->SetOperandLayout(ColMajorShape(rhs_shape), dot, 1)); } else if (PotentiallyImplementedAsEigenDot(*instruction)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, @@ -113,17 +120,17 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. - Shape output_shape(row_major_shape(dot->shape())); + Shape output_shape(RowMajorShape(dot->shape())); const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(row_major_shape(lhs_instruction->shape())); + Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); // dot is a kDot or a kTransposeDot fusion node. In the latter case, if // it represents X @ X, it may have just one operand. if (dot->operand_count() > 1) { const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(row_major_shape(rhs_instruction->shape())); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); } @@ -140,8 +147,12 @@ Status CpuLayoutAssignment::AddBackendConstraints( if (constraints->OperandBufferForwarded(instruction, operand_no)) { continue; } + // Skip operands with non-array shapes. + if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + continue; + } Shape operand_shape( - row_major_shape(instruction->operand(operand_no)->shape())); + RowMajorShape(instruction->operand(operand_no)->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( operand_shape, instruction, operand_no)); } diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h similarity index 86% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.h rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 4fd8d68dd6b4f2a8b16f6c048743a996ea76a560..c8edbb9e15a5b6f9c574f5fe9d130d149499ebd2 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -38,4 +38,4 @@ class CpuLayoutAssignment : public LayoutAssignment { } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc similarity index 99% rename from tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 1ea5e8c7fc4896512e62396d0a756cda44785f11..401cf50717959da95f48963c3c83b3036a80eb1b 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include #include diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index dba140d1120bc5502d2039e1663b9bf035d8d66a..09f028463af68bbc2841fecdb2ca6c6a42498798 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -15,11 +15,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "tensorflow/core/lib/strings/numbers.h" + namespace { const char* const kXlaParallelCpuOption = "xla_cpu_parallel"; const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce"; +const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; } // namespace @@ -45,6 +48,19 @@ bool VectorizedReduceDisabled(const HloModuleConfig& config) { return extra_options_map.count(kXlaOptimizeForSizeCpuOption) > 0; } +tensorflow::gtl::optional LlvmIrGemvTilingFactor( + const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + auto it = extra_options_map.find(kLlvmIrDotTilingFactor); + int64 tiling_factor; + if (it != extra_options_map.end() && + tensorflow::strings::safe_strto64(it->second, &tiling_factor)) { + return tiling_factor; + } + return tensorflow::gtl::nullopt; +} + } // namespace options } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h index 5dc24ebc7b8661092e3bc27c4f30fda1e497e41b..6ba0fd24538b63a3da81083482e6bee3b552dfea 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h @@ -27,6 +27,8 @@ namespace options { bool CpuParallelBackendRequested(const HloModuleConfig& config); bool OptimizeForSizeRequested(const HloModuleConfig& config); bool VectorizedReduceDisabled(const HloModuleConfig& config); +tensorflow::gtl::optional LlvmIrGemvTilingFactor( + const HloModuleConfig& config); } // namespace options } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc index f8e260dd90149405fff7beefba3f7fe83b75d4b6..f385829cdf5cafbd35e083f47106734cdd5dde88 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc @@ -12,15 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ - +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/local_client.h" diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.h b/tensorflow/compiler/xla/service/cpu/disassembler.h index b6feaa7e45cee26eb7f850081bd1fad2cb63b15c..5e302f88990ee4a3c37758881ecec4d6f71dd8e6 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.h +++ b/tensorflow/compiler/xla/service/cpu/disassembler.h @@ -37,7 +37,7 @@ struct DisassemblerResult { DisassemblerResult(const string& text, size_t code_size_bytes) : text(text), code_size_bytes(code_size_bytes) {} - // The dissassembled text sections of the object file. + // The disassembled text sections of the object file. string text; // The total number of bytes of executable code in the object file. uint64_t code_size_bytes; @@ -53,7 +53,7 @@ class Disassembler { // Returns a DisassemblerResult for the given object file, containing the // disassembled code. // - // If we couldnt' retrieve a disassembler for this platform, an error status + // If we couldn't retrieve a disassembler for this platform, an error status // is returned. StatusOr DisassembleObjectFile( const llvm::object::ObjectFile& object_file) const; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index e57d49172b18beb75cfbb482c5d732ef679ebe41..7f0bf2c8e4e26511e2e69121042540120c281c62 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,9 +23,10 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -38,6 +39,457 @@ using llvm_ir::SetToFirstInsertPoint; namespace cpu { +namespace { +// Loads a tile of values from a 2D tensor. +class TileLoader { + public: + // Constructs a TileLoader that will load a tile consisting of + // `tile_size_along_major_dim` vectors from the matrix `matrix`, starting at + // `major_dim_offset` in the major dimension. The tile size along the minor + // dimension is the vector size, and that is implicitly determined by `vsl`. + TileLoader(VectorSupportLibrary* vsl, llvm::IRBuilder<>* ir_builder, + llvm::Value* matrix, int64 matrix_size_along_minor_dim, + llvm::Value* major_dim_offset, int64 tile_size_along_major_dim) + : vsl_(vsl) { + pointers_.reserve(tile_size_along_major_dim); + for (int64 i = 0; i < tile_size_along_major_dim; i++) { + llvm::Value* total_offset = ir_builder->CreateMul( + ir_builder->getInt64(matrix_size_along_minor_dim), + ir_builder->CreateAdd(ir_builder->getInt64(i), major_dim_offset)); + pointers_.push_back(vsl_->ComputeOffsetPointer(matrix, total_offset)); + } + } + + // Load a tile consisting of `tile_size_along_major_dim_` vectors starting at + // `major_dim_offset_` in the major dimension and `minor_dim_offset` in the + // minor dimension. + std::vector LoadTile(llvm::Value* minor_dim_offset) const { + std::vector result; + result.reserve(pointers_.size()); + for (const auto& pointer : pointers_) { + result.push_back(vsl_->LoadVector(pointer, minor_dim_offset)); + } + return result; + } + + private: + VectorSupportLibrary* vsl_; + std::vector pointers_; +}; + +// Computes a dot product between "[M,K]{0,1} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly dividied into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +--+--+--+--+ +// |M00|M10|M20|M30| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M03|M13|M23|M33| |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// +// (Legend: rows are horizontal and columns are vertical; and each column is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is from the column major left matrix. +// b. The right tile is an elementwise broadcast of a [V0, V1, V2, V3] +// vector loaded from the RHS vector. +// +// As we iterate through the column dimension, we compute the change to the +// result vector by an elementwise multiplication between the two tiles above +// followed by a reduction along the major dimension: +// +// +-----------------------------------+ +// | M00*V0 + M10*V1 + M20*V2 + M30*V3 | +// +-----------------------------------+ +// | M01*V0 + M11*V1 + M21*V2 + M31*V3 | +// Result[R:R+4] += +-----------------------------------+ +// | M02*V0 + M12*V1 + M22*V2 + M32*V3 | +// +-----------------------------------+ +// | M03*V0 + M13*V1 + M23*V2 + M33*V3 | +// +-----------------------------------+ +// +// Where R is the starting row for the tile. +// +// We have an inner epilogue loop to deal with the "C" submatrix and an outer +// epilogue loop to deal with the B,D submarix. +// +// TODO(sanjoy): We should investigate if using gather loads and scatter stores +// can be used here have the same inner loop for both column-major and row-major +// matrix-vector products. +class ColumnMajorMatrixVectorProductEmitter { + public: + ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, + int64 tile_rows, int64 tile_cols, + int64 m, int64 k, llvm::Value* lhs, + llvm::Value* rhs, llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + lhs_(lhs), + rhs_(rhs), + result_(result), + ir_builder_(ir_builder), + ksl_(ir_builder_), + vsl_(scalar_type_, /*vector_size=*/tile_rows_, ir_builder_, "") { + CHECK(tile_rows_ > 0 && IsPowerOfTwo(static_cast(tile_rows_))); + } + + void Emit(); + + private: + void EmitOuterLoopBody(llvm::Value* column, int64 column_count, + bool is_first_column); + + TileLoader GetLhsTileLoader(llvm::Value* column_start, int64 column_count) { + return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/m_, + /*major_dim_offset=*/column_start, + /*tile_size_along_major_dim=*/column_count); + } + + // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous + // sequnce of `count` values, each one broadcasted to the vector width. + std::vector LoadRhsTile(llvm::Value* offset, int64 count) { + llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); + std::vector result; + result.reserve(count); + for (int64 i = 0; i < count; i++) { + result.push_back(vsl_.LoadBroadcast(base_pointer, i)); + } + return result; + } + + void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, + const std::vector& rhs_tile, + int64 columns, bool is_first_column); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_col, int64 columns, + bool is_first_tiled_column); + + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void ColumnMajorMatrixVectorProductEmitter::EmitOuterLoopBody( + llvm::Value* column, int64 column_count, bool is_first_column) { + TileLoader lhs_tile_loader = GetLhsTileLoader(/*column_start=*/column, + /*column_count=*/column_count); + + std::vector rhs_tile = + LoadRhsTile(column, /*count=*/column_count); + EmitInnerLoopTiled(&lhs_tile_loader, rhs_tile, + /*columns=*/column_count, is_first_column); + EmitInnerLoopEpilogue(column, /*columns=*/column_count, is_first_column); +} + +void ColumnMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. + int64 column_remainder = k_ % tile_cols_; + int64 column_limit = k_ - column_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/column_limit, /*step=*/tile_cols_, + [&](llvm::Value* column, bool is_first_column) { + EmitOuterLoopBody(column, tile_cols_, is_first_column); + }); + + if (column_remainder != 0) { + EmitOuterLoopBody(ir_builder_->getInt64(column_limit), column_remainder, + column_limit == 0); + } +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + TileLoader* lhs_tile_loader, const std::vector& rhs_tile, + int64 columns, bool is_first_column) { + int64 row_limit = m_ - (m_ % tile_rows_); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/row_limit, + /*step=*/tile_rows_, [&](llvm::Value* row) { + std::vector lhs_tile = + lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); + llvm::Value* accumulator = is_first_column + ? vsl_.GetZeroVector() + : vsl_.LoadVector(result_, row); + for (int i = 0; i < columns; i++) { + accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); + } + vsl_.StoreVector(accumulator, result_, row); + }); +} + +void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_col, int64 columns, bool is_first_tiled_column) { + int64 row_start = m_ - (m_ % tile_rows_); + if (row_start == m_) { + return; + } + + llvm::Value* columns_llvm = ir_builder_->getInt64(columns); + + // for (col = current_tile_col; col < (columns + current_tile_col); col++) + // for (row = row_start, row < m_; row++) { + // result[row] += lhs[row, col] * rhs[col] + // // Also take into account that if col is 0 then result[row] is not + // // initialized. + // } + + ksl_.For( + "dot.inner.epilg.outer", /*start=*/current_tile_col, + /*end=*/ir_builder_->CreateAdd(columns_llvm, current_tile_col), + /*step=*/1, /*peel_first_iteration=*/false, + [&](llvm::Value* col, llvm::Value* is_first_scalar_col) { + llvm::Value* rhs_element = vsl_.LoadScalar(rhs_, col); + llvm::Value* total_offset = + ir_builder_->CreateMul(col, ir_builder_->getInt64(m_)); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For( + "dot.inner.epilg.inner", /*start=*/row_start, /*end=*/m_, + /*step=*/1, [&](llvm::Value* scalar_row) { + llvm::Value* product = vsl_.Mul( + vsl_.LoadScalar(lhs_base_pointer, scalar_row), rhs_element); + llvm::Value* setting_result_first_time = ir_builder_->CreateAnd( + is_first_scalar_col, + ir_builder_->getInt1(is_first_tiled_column)); + ksl_.If( + setting_result_first_time, + [&]() { vsl_.StoreScalar(product, result_, scalar_row); }, + [&]() { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), + result_, scalar_row); + }); + }); + }); +} + +// Computes a dot product between "[M,K]{1,0} lhs" with a [K,1] vector (the +// layout of the vector does not matter). This implementation uses a tiling +// scheme to improve performance. +// +// We logically separate the LHS matrix into four segments: +// +// +----------------------+---+ +// | | | +// | | | +// | A | B | +// | | | +// | | | +// | | | +// +----------------------+---+ +// | C | D | +// +----------------------+---+ +// +// where A is the largest submatrix of the LHS that can be evenly dividied into +// tiles. For each tile in A, assuming tile_rows_ == tile_cols_ == 4, we have: +// +// +---+---+---+---+ +// |M00|M10|M20|M30| +// +---+---+---+---+ +--+--+--+--+ +// |M01|M11|M21|M31| and |V0|V1|V2|V3| +// +---+---+---+---+ +--+--+--+--+ +// |M02|M12|M22|M32| +// +---+---+---+---+ +// |M03|M13|M23|M33| +// +---+---+---+---+ +// +// (Legend: rows are horizontal and columns are vertical; and each row is one +// llvm::Value of a vector type) +// +// where: +// +// a. The left tile is loaded from the row major left matrix. +// b. The right vector is loaded from the RHS vector. +// +// We keep 4 vector accumulators accumulating the following four vector +// expressions as we iterate over the row dimension: +// +// +------+------+------+------+ +// |M0I*V0|M1I*V1|M2I*V2|M3I*V3| for I in [0,4) +// +------+------+------+------+ +// +// In the end we do a horizontal reduction over these 4 vector accumulators to +// get 4 values in the result vector. +// +// We have an inner epilogue loop to deal with the "B" sub-matrix and an outer +// epilogue loop to deal with the C,D submatrix. +class RowMajorMatrixVectorProductEmitter { + public: + RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, + int64 tile_cols, int64 m, int64 k, + llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* result, + llvm::IRBuilder<>* ir_builder) + : scalar_type_(scalar_type), + tile_rows_(tile_rows), + tile_cols_(tile_cols), + m_(m), + k_(k), + lhs_(lhs), + rhs_(rhs), + result_(result), + ir_builder_(ir_builder), + ksl_(ir_builder_), + vsl_(scalar_type_, /*vector_size=*/tile_cols_, ir_builder_, "") { + CHECK(tile_cols_ > 0 && IsPowerOfTwo(static_cast(tile_cols_))); + } + + void Emit(); + + private: + TileLoader GetLhsTileLoader(llvm::Value* row_start, int64 row_count) { + return TileLoader(&vsl_, ir_builder_, /*matrix=*/lhs_, + /*matrix_size_along_minor_dim=*/k_, + /*major_dim_offset=*/row_start, + /*tile_size_along_major_dim=*/row_count); + } + + void EmitOuterLoopBody(llvm::Value* row, int64 row_count); + + void EmitInnerLoopTiled(TileLoader* lhs_tile_loader, int64 rows, + std::vector* vector_accumulators); + + void EmitInnerLoopEpilogue(llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators); + + PrimitiveType scalar_type_; + int64 tile_rows_; + int64 tile_cols_; + int64 m_; + int64 k_; + llvm::Value* lhs_; + llvm::Value* rhs_; + llvm::Value* result_; + llvm::IRBuilder<>* ir_builder_; + KernelSupportLibrary ksl_; + VectorSupportLibrary vsl_; +}; + +void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, + int64 row_count) { + TileLoader lhs_tile_loader = GetLhsTileLoader(/*row_start=*/row, + /*row_count=*/row_count); + std::vector vector_accumulators; + std::vector scalar_accumulators; + for (int i = 0; i < row_count; i++) { + vector_accumulators.emplace_back(&vsl_, vsl_.GetZeroVector()); + scalar_accumulators.emplace_back(&vsl_, vsl_.GetZeroScalar()); + } + EmitInnerLoopTiled(&lhs_tile_loader, /*rows=*/row_count, + &vector_accumulators); + EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, + &scalar_accumulators); + + std::vector accumulator_values; + std::transform( + vector_accumulators.begin(), vector_accumulators.end(), + std::back_inserter(accumulator_values), + [](const VectorVariable& vector_var) { return vector_var.Get(); }); + std::vector horizontal_sums = + vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + + for (int i = 0; i < row_count; i++) { + llvm::Value* result_value = + vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); + llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); + vsl_.StoreScalar(result_value, result_, offset); + } +} + +void RowMajorMatrixVectorProductEmitter::Emit() { + // See the comment on the class declaration for the algorithm used here. + int64 row_remainder = m_ % tile_rows_; + int64 row_limit = m_ - row_remainder; + + ksl_.For("dot.outer.tiled", + /*start=*/0, /*end=*/row_limit, /*step=*/tile_rows_, + [&](llvm::Value* row) { EmitOuterLoopBody(row, tile_rows_); }); + + if (row_remainder != 0) { + EmitOuterLoopBody(ir_builder_->getInt64(row_limit), row_remainder); + } +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( + TileLoader* lhs_tile_loader, int64 rows, + std::vector* vector_accumulators) { + int64 column_limit = k_ - (k_ % tile_cols_); + + ksl_.For("dot.inner.tiled", /*start=*/0, /*end=*/column_limit, + /*step=*/tile_cols_, [&](llvm::Value* col) { + std::vector lhs_tile = + lhs_tile_loader->LoadTile(/*minor_dim_offset=*/col); + llvm::Value* rhs_value = vsl_.LoadVector(rhs_, col); + for (int i = 0; i < rows; i++) { + llvm::Value* old_sum = (*vector_accumulators)[i].Get(); + (*vector_accumulators)[i].Set( + vsl_.Add(old_sum, vsl_.Mul(rhs_value, lhs_tile[i]))); + } + }); +} + +void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( + llvm::Value* current_tile_row, int64 rows, + std::vector* scalar_accumulators) { + int64 column_start = k_ - (k_ % tile_cols_); + if (column_start == k_) { + return; + } + + for (int r = 0; r < rows; r++) { + llvm::Value* total_offset = ir_builder_->CreateMul( + ir_builder_->CreateAdd(ir_builder_->getInt64(r), current_tile_row), + ir_builder_->getInt64(k_)); + llvm::Value* lhs_base_pointer = + vsl_.ComputeOffsetPointer(lhs_, total_offset); + ksl_.For("dot.inner.epilg.inner", /*start=*/column_start, /*end=*/k_, + /*step=*/1, [&](llvm::Value* scalar_col) { + llvm::Value* product = + vsl_.Mul(vsl_.LoadScalar(lhs_base_pointer, scalar_col), + vsl_.LoadScalar(rhs_, scalar_col)); + llvm::Value* old_value = (*scalar_accumulators)[r].Get(); + (*scalar_accumulators)[r].Set(vsl_.Add(old_value, product)); + }); + } +} + +} // namespace + DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, @@ -72,6 +524,122 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } +bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { + if (dot_.shape().dimensions_size() != 2) { + return false; + } + + PrimitiveType primitive_type = dot_.shape().element_type(); + + if (!primitive_util::IsFloatingPointType(primitive_type) && + !primitive_util::IsIntegralType(primitive_type)) { + return false; + } + + MatMultDims mat_mult_dims = GetMatMultDims(); + bool is_column_major_matrix_vector = false; + bool is_row_major_matrix_vector = false; + + int64 m, k; + bool swap_operands; + + if (mat_mult_dims.m == 1) { + bool rhs_effectively_row_major = + transpose_rhs_ ^ !mat_mult_dims.rhs_column_major; + if (rhs_effectively_row_major) { + k = mat_mult_dims.k; + m = mat_mult_dims.n; + is_column_major_matrix_vector = true; + swap_operands = true; + } else { + k = mat_mult_dims.k; + m = mat_mult_dims.n; + is_row_major_matrix_vector = true; + swap_operands = true; + } + } + + if (mat_mult_dims.n == 1) { + bool lhs_effectively_column_major = + transpose_lhs_ ^ mat_mult_dims.lhs_column_major; + if (lhs_effectively_column_major) { + m = mat_mult_dims.m; + k = mat_mult_dims.k; + is_column_major_matrix_vector = true; + swap_operands = false; + } else { + m = mat_mult_dims.m; + k = mat_mult_dims.k; + is_row_major_matrix_vector = true; + swap_operands = false; + } + } + + if (!is_column_major_matrix_vector && !is_row_major_matrix_vector) { + return false; + } + + int64 tiling_factor = GetGemvTilingFactor(); + CHECK_GT(tiling_factor, 0); + + llvm::Value* result_op = target_array_.GetBasePointer(); + llvm::Value* lhs_op = + swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer(); + llvm::Value* rhs_op = + swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + if (is_column_major_matrix_vector) { + VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m + << " and k = " << k; + int64 tile_rows = 8; + int64 tile_cols = tiling_factor; + + string kernel_name = tensorflow::strings::StrCat( + "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) { + ColumnMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + result_op, ir_builder_); + emitter.Emit(); + }); + } else { + VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m + << " and k = " << k; + int64 tile_rows = tiling_factor; + int64 tile_cols = 8; + + string kernel_name = tensorflow::strings::StrCat( + "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) { + RowMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + result_op, ir_builder_); + emitter.Emit(); + }); + } + + return true; +} + tensorflow::Status DotOpEmitter::Emit() { // The dot operation performs a sum of products over dimension 0 of the left // hand side operand and dimension 1 of the right hand side operand. @@ -105,6 +673,10 @@ tensorflow::Status DotOpEmitter::Emit() { return EmitScalarDot(); } + if (EmitLlvmIrDotIfProfitable()) { + return Status::OK(); + } + if (PotentiallyImplementedAsEigenDot(dot_)) { return EmitCallToRuntime(); } @@ -340,22 +912,17 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { // // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'. - const Shape& lhs_shape = lhs_array_.GetShape(); - const Shape& rhs_shape = rhs_array_.GetShape(); + MatMultDims mat_mult_dims = GetMatMultDims(); - CHECK(LayoutUtil::Equal(lhs_shape.layout(), rhs_shape.layout())); + CHECK_EQ(mat_mult_dims.lhs_column_major, mat_mult_dims.rhs_column_major); - int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0); - int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1); - int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1); const llvm_ir::IrArray* lhs = &lhs_array_; const llvm_ir::IrArray* rhs = &rhs_array_; bool transpose_lhs = transpose_lhs_; bool transpose_rhs = transpose_rhs_; - bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0; - if (!is_column_major) { - std::swap(m, n); + if (!mat_mult_dims.lhs_column_major) { + std::swap(mat_mult_dims.m, mat_mult_dims.n); std::swap(lhs, rhs); std::swap(transpose_lhs, transpose_rhs); } @@ -367,12 +934,27 @@ tensorflow::Status DotOpEmitter::EmitCallToRuntime() { float_ptr_type), ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type), ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type), - ir_builder_->getInt64(m), ir_builder_->getInt64(n), - ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs), + ir_builder_->getInt64(mat_mult_dims.m), + ir_builder_->getInt64(mat_mult_dims.n), + ir_builder_->getInt64(mat_mult_dims.k), + ir_builder_->getInt32(transpose_lhs), ir_builder_->getInt32(transpose_rhs)}); return tensorflow::Status::OK(); } +DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { + CHECK_EQ(dot_.shape().dimensions_size(), 2); + + const Shape& lhs_shape = lhs_array_.GetShape(); + const Shape& rhs_shape = rhs_array_.GetShape(); + + return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), + lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), + rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), + lhs_shape.layout().minor_to_major(0) == 0, + rhs_shape.layout().minor_to_major(0) == 0}; +} + llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix) { @@ -403,5 +985,82 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( return index; } +// Return whether the given shape is a matrix with no padding. +static bool IsRank2WithNoPadding(const Shape& shape) { + return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); +} + +// In a gemm operation where output = lhs * rhs, check whether the given shapes +// are valid for the operation. +static bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, + const Shape& output_shape) { + // The inputs and the output must + // 1) be matrices with no padding, and + // 2) have an allowed element type. + return output_shape.element_type() == F32 && + IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && + IsRank2WithNoPadding(output_shape); +} + +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { + // For certain types of Dot, we can call Eigen + if (hlo.opcode() == HloOpcode::kDot) { + const Shape& lhs_shape = hlo.operand(0)->shape(); + const Shape& rhs_shape = hlo.operand(1)->shape(); + + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + + if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { + return false; + } + + // If gemm can accept the operand shapes, use it rather than a custom + // kernel. + if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { + // The size of the reduction dimension should match. The shape inference + // guarantees this invariant, so the check here is for programming + // errors. + CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); + return true; + } + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && + hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { + auto* dot = hlo.fused_expression_root(); + const Shape& lhs_shape = dot->operand(0)->shape(); + const Shape& rhs_shape = dot->operand(1)->shape(); + if (ShapeUtil::HasZeroElements(lhs_shape) || + ShapeUtil::HasZeroElements(rhs_shape)) { + return false; + } + return true; + } + + return false; +} + +// For vector-matrix dot products, it is always profitable to make the Rhs +// column major. +bool ProfitableToMakeDotRhsColumnMajor(const HloInstruction& hlo) { + return hlo.opcode() == HloOpcode::kDot && + hlo.shape().dimensions_size() == 2 && hlo.shape().dimensions(0) == 1; +} + +bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { + // Any Matrix-Vector product of floating point or integral type, or + // a transpose-dot fusion of the same can be lowered to a tiled LLVM + // IR implementation. + const Shape& shape = dot.shape(); + return shape.dimensions_size() == 2 && + (shape.dimensions(0) == 1 || shape.dimensions(1) == 1) && + (primitive_util::IsFloatingPointType(shape.element_type()) || + primitive_util::IsIntegralType(shape.element_type())); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index cfc10660453c822635d68270c053977fca779ee1..2badb26f905d6f1fe6de00401f7800b774f44c07 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_ #include "llvm/IR/IRBuilder.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -29,6 +30,16 @@ limitations under the License. namespace xla { namespace cpu { +bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); + +// Returns true to indicate that |hlo| is a dot, and that it is profitable to +// switch the layout of the |hlo|'s RHS operand to column major. +bool ProfitableToMakeDotRhsColumnMajor(const HloInstruction& hlo); + +// Returns true to indicate that we can generate a tiled LLVM IR implementation +// for |dot|. +bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot); + // Helper class for emitting LLVM IR to perform the dot operation. class DotOpEmitter { public: @@ -59,6 +70,10 @@ class DotOpEmitter { // LHS and RHS) and store the results in the target. tensorflow::Status EmitScalarDot(); + // Emit an LLVM IR implementation of the dot operation if we can. Returns + // true if an LLVM IR implementation was emitted. + bool EmitLlvmIrDotIfProfitable(); + // Emits a call to the CPU runtime to perform the matrix multiply. tensorflow::Status EmitCallToRuntime(); @@ -77,6 +92,38 @@ class DotOpEmitter { // no padding, and a rank of two. bool ShapesAreLegalForRuntimeDot() const; + // Represents the dimensions of a matrix-matrix multiply operation. + struct MatMultDims { + // The number of rows in the LHS. + int64 m; + + // The number of columns in the LHS, which is also must be equal to the + // number of rows in the RHS. + int64 k; + + // The number of columns on the RHS. + int64 n; + + // True if the LHS matrix column major. + bool lhs_column_major; + + // True if the RHS matrix column major. + bool rhs_column_major; + }; + + // Get the MatMultDims instance for the dot product this DotOpEmitter + // represents. Precondition: the dot is of rank 2 (and thus its operands are + // of rank 2 as well). + MatMultDims GetMatMultDims() const; + + // When doing a tiled GEMV in LLVM IR, a "tile" consists of this many vector + // registers. + int64 GetGemvTilingFactor() const { + const int64 kDefaultTilingFactor = 8; + return options::LlvmIrGemvTilingFactor(hlo_module_config_) + .value_or(kDefaultTilingFactor); + } + const HloInstruction& dot_; const bool transpose_lhs_; const bool transpose_rhs_; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index b99b36a55eee40bc66dcb1b7b1a464bf764ef0ea..3993779da636e519f8d8fded468c3271d27ee093 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -29,10 +29,8 @@ bool PotentiallyImplementedAsEigenConvolution( // The following conditions are necessary (but not sufficient) for // implementing `convolution` with Eigen convolution: // - the input and kernel have a non-zero number of elements. - // - the input is in NHWC or NWHC order. - // - the kernel is in HWIO or WHIO order. - // - the spatial dimensions are in the same relative order in the input, - // kernel and output. + // - the input is in NHWC order. + // - the kernel is in HWIO order. // // To be sufficient, certain layout constraints need to be satisfied as well. const Shape& input_shape = convolution.operand(0)->shape(); @@ -51,15 +49,22 @@ bool PotentiallyImplementedAsEigenConvolution( convolution.convolution_dimension_numbers(); // Only 1D and 2D convolutions are supported at the moment. // TODO(b/32897908): add an optimized implementation for 3D convolution. - if (dnums.spatial_dimensions_size() > 2) { + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + if (num_spatial_dims > 2) { return false; } - bool input_spatial_dims_ascending = std::is_sorted( - dnums.spatial_dimensions().begin(), dnums.spatial_dimensions().end()); - bool kernel_spatial_dims_ascending = - std::is_sorted(dnums.kernel_spatial_dimensions().begin(), - dnums.kernel_spatial_dimensions().end()); + for (int64 i = 0; i < num_spatial_dims; ++i) { + if (dnums.input_spatial_dimensions(i) != i + 1) { + return false; + } + if (dnums.kernel_spatial_dimensions(i) != i) { + return false; + } + if (dnums.output_spatial_dimensions(i) != i + 1) { + return false; + } + } const Shape& output_shape = convolution.shape(); return dnums.input_batch_dimension() == 0 && @@ -67,116 +72,11 @@ bool PotentiallyImplementedAsEigenConvolution( dnums.output_batch_dimension() == 0 && dnums.output_feature_dimension() == output_shape.dimensions_size() - 1 && - input_spatial_dims_ascending == kernel_spatial_dims_ascending && dnums.kernel_input_feature_dimension() == kernel_shape.dimensions_size() - 2 && dnums.kernel_output_feature_dimension() == kernel_shape.dimensions_size() - 1; } -namespace { - -// Return whether the given shape is a matrix with no padding. -bool IsRank2WithNoPadding(const Shape& shape) { - return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape); -} - -// In a gemm operation where output = lhs * rhs, check whether the given shapes -// are valid for the operation. -bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape, - const Shape& output_shape) { - // The inputs and the output must - // 1) be matrices with no padding, and - // 2) have an allowed element type. - return output_shape.element_type() == F32 && - IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) && - IsRank2WithNoPadding(output_shape); -} -} // namespace - -bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { - // For certain types of Dot, we can call Eigen - if (hlo.opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); - - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - - if (ProfitableToImplementDotInLlvmIr(hlo) == DotInLlvmIrProfitable::kYes) { - return false; - } - - // If gemm can accept the operand shapes, use it rather than a custom - // kernel. - if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) { - // The size of the reduction dimension should match. The shape inference - // guarantees this invariant, so the check here is for programming - // errors. - CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0)); - return true; - } - } - - if (hlo.opcode() == HloOpcode::kFusion && - hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && - hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - auto* dot = hlo.fused_expression_root(); - const Shape& lhs_shape = dot->operand(0)->shape(); - const Shape& rhs_shape = dot->operand(1)->shape(); - if (ShapeUtil::HasZeroElements(lhs_shape) || - ShapeUtil::HasZeroElements(rhs_shape)) { - return false; - } - return true; - } - - return false; -} - -DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr( - const HloInstruction& dot) { - if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { - const Shape& result_shape = dot.shape(); - // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 - // cache line size, so that we can have the reduction dimension of both the - // LHS and RHS matrices and still have some space "left over". This needs - // to be tuned further. - const int64 kReductionDimensionThresholdBytes = 8 * 1024; - const bool single_threaded_eigen = - !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); - - // This is the point at which it is better to call into Eigen and shard the - // dot across multiple worker threads. This is a rough estimate by running - // a matmult benchmark on my local machine, and it can be tuned further. - const int64 kMaxSingleThreadedFlops = 16 * 1024; - - const int64 M = result_shape.dimensions(0); - const int64 N = result_shape.dimensions(1); - const int64 K = dot.operand(1)->shape().dimensions(0); - const int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); - if (M == 1 && - K * primitive_type_size <= kReductionDimensionThresholdBytes && - (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { - // Heuristics: - // - // - Look for a configuration where we will likely be able to keep LHS in - // L1 and do a cache-optimal traversal of RHS. - // - // - Bail out on matrices that are large enough that Eigen can profitably - // shard the computation across multiple cores. This only applies when - // multi-threading is enabled. - return LayoutUtil::IsMonotonicWithDim0Major( - dot.operand(1)->shape().layout()) - ? DotInLlvmIrProfitable::kWithColumnMajorRhs - : DotInLlvmIrProfitable::kYes; - } - } - return DotInLlvmIrProfitable::kNo; -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index 66656ed99765806ec4463f3781644853886cf303..34b2003916933f5ec0a15d9e219063c0a912fa40 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -24,20 +25,17 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution); -bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot); - -enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; - -// Returns a value to indicate if (and under what conditions) will lowering -// |dot| as a pure LLVM IR dot operation be profitable over calling into Eigen. -// Possible return values are: +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] // -// * DotInLlvmIrProfitable::kYes - always profitable. -// * DotInLlvmIrProfitable::kNo - never profitable. -// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make -// the Rhs layout column major. -DotInLlvmIrProfitable ProfitableToImplementDotInLlvmIr( - const HloInstruction& dot); +// See IrFunction and ParallelLoopEmitter for details. +using DynamicLoopBounds = std::vector>; } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index a20ce6826ca0a86f8c0d441c1e89f091cfb434f1..c82a0c7ef4a797d9e1cf853badc84130a3e062b1 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,16 +24,17 @@ limitations under the License. #include #include +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/Target/TargetRegisterInfo.h" -#include "llvm/Target/TargetSubtargetInfo.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -42,6 +43,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -76,14 +79,16 @@ namespace cpu { IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, - const std::unordered_map* hlo_to_profile_idx, + std::unordered_map hlo_to_profile_idx, + tensorflow::gtl::optional entry_computation_profile_idx, llvm::TargetMachine* target_machine, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), ir_builder_(llvm_module->getContext()), - hlo_to_profile_idx_(hlo_to_profile_idx), + hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), + entry_computation_profile_idx_(std::move(entry_computation_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), parallel_cpu_backend_( @@ -122,133 +127,27 @@ StatusOr IrEmitter::EmitComputation( } else { TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); } - InsertOrDie(&emitted_functions_, computation, compute_function_); - - return compute_function_; -} - -static llvm::Argument* GetArg(llvm::Function* f, int idx) { - llvm::Function::arg_iterator arg_iter = f->arg_begin(); - std::advance(arg_iter, idx); - return &*arg_iter; + llvm::Function* ir_function = compute_function_->function(); + InsertOrDie(&emitted_functions_, computation, ir_function); + // Delete 'compute_function', finalizing 'ir_function' and restoring caller + // IR insert point. + compute_function_.reset(); + return ir_function; } void IrEmitter::InitializeIrFunction(const string& function_name) { - // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* dynamic_loop_bounds, i64* prof_counters) - // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. - // - // Therefore, the generated function's signature (FunctionType) is statically - // determined - parameter unpacking is done in code generated into the - // function, rather than by a prologue dictated by the platform ABI. - // - // /--------------\ - // retval ----------> | return value | - // \--------------/ - // - // /-------------------------------\ - // run_options -----> | xla::ExecutableRunOptions | - // \-------------------------------/ - // - // /---------------------------------------------\ - // params --------> | param 0 | param 1 | ..... | param N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | param 0 | | param 1 | | param N-1 | - // \---------/ \---------/ \-----------/ - // - // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | temp 0 | | temp 1 | | temp N-1 | - // \---------/ \---------/ \-----------/ - // - // /--------------------------------------------\ - // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| - // (elided for aot) \--------------------------------------------/ - // - // /---------------------------------------------\ - // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | - // (elided for aot) \---------------------------------------------/ - - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. - llvm::FunctionType* compute_function_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/GetComputeFunctionParams(), - /*isVarArg=*/false); - // Functions with local linkage get an inlining bonus. Because we know // a-priori that embedded functions (non-entry functions) will not have its // name resolved, give it local linkage. llvm::Function::LinkageTypes linkage = is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; - compute_function_ = - llvm::Function::Create(/*Ty=*/compute_function_type, - /*Linkage=*/linkage, - /*Name=*/AsStringRef(function_name), - /*Module=*/module_); - compute_function_->setCallingConv(llvm::CallingConv::C); - - // Set meaningful names for the function's arguments: useful for debugging. - llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin(); - arg_iter->setName("retval"); - (++arg_iter)->setName("run_options"); - (++arg_iter)->setName("params"); - (++arg_iter)->setName("temps"); - if (num_dynamic_loop_bounds_ > 0) { - (++arg_iter)->setName("dynamic_loop_bounds"); - } - if (hlo_to_profile_idx_) { - (++arg_iter)->setName("prof_counters"); - } - - // We know a-priori that the function arguments are guaranteed to point to - // disjoint objects. - llvm::Argument* retval = GetResultArgument(); - for (llvm::Argument& argument : compute_function_->args()) { - // However, the return buffer aliases the temporaries and thus cannot be - // marked noalias. - if (&argument == retval) { - continue; - } - compute_function_->addAttribute(argument.getArgNo() + 1, - llvm::Attribute::NoAlias); - } - - // Add the optize attribute to the function if optimizing for size. This - // controls internal behavior of some optimization passes (e.g. loop - // unrolling). - if (options::OptimizeForSizeRequested(hlo_module_config_)) { - compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize); - } - - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - compute_function_->addFnAttr("unsafe-fp-math", "true"); - compute_function_->addFnAttr("no-infs-fp-math", "true"); - compute_function_->addFnAttr("no-nans-fp-math", "true"); - compute_function_->addFnAttr("no-signed-zeros-fp-math", "true"); - } - - ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( - /*Context=*/module_->getContext(), - /*Name=*/"entry", - /*Parent=*/compute_function_)); + // Create and initialize new IrFunction. + compute_function_.reset( + new IrFunction(function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_enable_fast_math(), + module_, &ir_builder_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -344,11 +243,12 @@ int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { // Calculate the alignment of a buffer allocated for a given primitive type. int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { - int64 buffer_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - DCHECK_GE(buffer_size, 0); - DCHECK_LE(buffer_size, SIZE_MAX); - - return MinimumAlignmentForBufferSize(buffer_size); + int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + DCHECK_GE(byte_size, 0); + // Largest scalar is a complex64 so we don't need to worry about the + // int64->int truncation here. + DCHECK_LE(byte_size, 8); + return byte_size; } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -357,6 +257,10 @@ int64 IrEmitter::ByteSizeOf(const Shape& shape) const { // Calculate the alignment of a buffer allocated for a given shape. int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { + if (ShapeUtil::IsScalar(shape)) { + return MinimumAlignmentForPrimitiveType(shape.element_type()); + } + int64 buffer_size = ByteSizeOf(shape); DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); @@ -612,7 +516,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32})); + /*supported_types=*/{F32, BF16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { @@ -795,7 +699,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // operand index is within the bounds. The unsigned comparison includes // checking whether the operand index >= 0. llvm_ir::IrArray::Index operand_index(source_index.size()); - llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); + llvm::Value* in_bounds_condition = ir_builder_.getTrue(); for (int64 i = 0; i < rank; ++i) { llvm::Value* strided_index = ir_builder_.CreateNSWMul( source_index[i], ir_builder_.getInt64(window.dimensions(i).stride())); @@ -822,14 +726,16 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the initialized_flag is false, initialize the selected value and index // with the currently visiting operand. SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_); - const auto save_operand_index = [&]( - const llvm_ir::IrArray::Index& operand_index) { - for (int64 i = 0; i < rank; ++i) { - llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP( - selected_index_address, {ir_builder_.getInt32(i)}); - ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); - } - }; + const auto save_operand_index = + [&](const llvm_ir::IrArray::Index& operand_index) { + for (int64 i = 0; i < rank; ++i) { + llvm::Value* selected_index_address_slot = + ir_builder_.CreateInBoundsGEP(selected_index_address, + {ir_builder_.getInt32(i)}); + ir_builder_.CreateStore(operand_index[i], + selected_index_address_slot); + } + }; llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); @@ -896,6 +802,24 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32, F64, C64})); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions_size() != 1) { + // This is disallowed by ShapeInference today. + return Unimplemented( + "Dot with multiple contracting dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions(0) != + std::min(lhs->shape().dimensions_size() - 1, 1) || + dnums.rhs_contracting_dimensions(0) != 0) { + return Unimplemented( + "Dot with non-standard contracting dimensions not implemented."); + } llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -952,11 +876,12 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Input tensor. const Shape& input_shape = convolution->operand(0)->shape(); int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension()); - int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0)); + int64 input_rows = + input_shape.dimensions(dnums.input_spatial_dimensions(0)); int64 input_cols = one_dim_convolution ? 1 - : input_shape.dimensions(dnums.spatial_dimensions(1)); + : input_shape.dimensions(dnums.input_spatial_dimensions(1)); int64 input_channels = input_shape.dimensions(dnums.input_feature_dimension()); @@ -976,11 +901,11 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Output tensor. const Shape& convolution_shape = convolution->shape(); int64 output_rows = - convolution_shape.dimensions(dnums.spatial_dimensions(0)); - int64 output_cols = - one_dim_convolution - ? 1 - : convolution_shape.dimensions(dnums.spatial_dimensions(1)); + convolution_shape.dimensions(dnums.output_spatial_dimensions(0)); + int64 output_cols = one_dim_convolution + ? 1 + : convolution_shape.dimensions( + dnums.output_spatial_dimensions(1)); // Extract the window stride for the convolution. const Window& window = convolution->window(); @@ -1068,10 +993,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { return EmitTargetElementLoop( convolution, [this, convolution, lhs, rhs, window, dnums](const llvm_ir::IrArray::Index& index) { - int num_spatial_dims = dnums.spatial_dimensions_size(); + int num_spatial_dims = dnums.output_spatial_dimensions_size(); std::vector output_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { - output_spatial[i] = index[dnums.spatial_dimensions(i)]; + output_spatial[i] = index[dnums.output_spatial_dimensions(i)]; } llvm::Value* output_feature = index[dnums.output_feature_dimension()]; llvm::Value* batch = index[dnums.output_batch_dimension()]; @@ -1091,8 +1016,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { for (int i = 0; i < num_spatial_dims; ++i) { kernel_spatial[i] = loops - .AddLoop(0, rhs->shape().dimensions( - dnums.kernel_spatial_dimensions(i)), + .AddLoop(0, + rhs->shape().dimensions( + dnums.kernel_spatial_dimensions(i)), tensorflow::strings::StrCat("k", i)) ->GetIndVarValue(); } @@ -1108,17 +1034,18 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { // Calculate the spatial index in the input array, taking striding, // dilation and padding into account. An index in the padding will be // out of the bounds of the array. - const auto calculate_input_index = [this]( - llvm::Value* output_index, llvm::Value* kernel_index, - const WindowDimension& window_dim) { - llvm::Value* strided_index = ir_builder_.CreateNSWMul( - output_index, ir_builder_.getInt64(window_dim.stride())); - llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( - kernel_index, ir_builder_.getInt64(window_dim.window_dilation())); - return ir_builder_.CreateNSWSub( - ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), - ir_builder_.getInt64(window_dim.padding_low())); - }; + const auto calculate_input_index = + [this](llvm::Value* output_index, llvm::Value* kernel_index, + const WindowDimension& window_dim) { + llvm::Value* strided_index = ir_builder_.CreateNSWMul( + output_index, ir_builder_.getInt64(window_dim.stride())); + llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul( + kernel_index, + ir_builder_.getInt64(window_dim.window_dilation())); + return ir_builder_.CreateNSWSub( + ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index), + ir_builder_.getInt64(window_dim.padding_low())); + }; std::vector input_spatial(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { input_spatial[i] = calculate_input_index( @@ -1140,11 +1067,11 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0)); }; - llvm::Value* in_bounds_condition = nullptr; + llvm::Value* in_bounds_condition = ir_builder_.getInt1(true); for (int i = 0; i < num_spatial_dims; ++i) { llvm::ConstantInt* input_bound = ir_builder_.getInt64(window_util::DilatedBound( - lhs->shape().dimensions(dnums.spatial_dimensions(i)), + lhs->shape().dimensions(dnums.input_spatial_dimensions(i)), window.dimensions(i).base_dilation())); llvm::Value* dim_in_bound = ir_builder_.CreateICmpULT(input_spatial[i], input_bound); @@ -1153,9 +1080,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { llvm::Value* dim_ok = ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole); in_bounds_condition = - in_bounds_condition - ? ir_builder_.CreateAnd(in_bounds_condition, dim_ok) - : dim_ok; + ir_builder_.CreateAnd(in_bounds_condition, dim_ok); } // Now we need to map the dilated base coordinates back to the actual @@ -1178,7 +1103,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { int num_dims = num_spatial_dims + 2; llvm_ir::IrArray::Index input_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - input_index[dnums.spatial_dimensions(i)] = input_spatial[i]; + input_index[dnums.input_spatial_dimensions(i)] = input_spatial[i]; } input_index[dnums.input_feature_dimension()] = input_feature; input_index[dnums.input_batch_dimension()] = batch; @@ -1449,7 +1374,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { // // Where Param is the actual element type of the underlying buffer (for // example, float for an XLA F32 element type). - llvm::Argument* params = GetArg(compute_function_, 2); + llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = @@ -1587,7 +1512,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( // Here we assume that the largest register is a vector register. int max_vector_register_size_in_bytes = target_machine_features_.largest_register_size_in_bytes( - compute_function_); + compute_function_->function()); int vector_register_size_in_elements = max_vector_register_size_in_bytes / @@ -1745,19 +1670,6 @@ void IrEmitter::EmitShardedVectorStore( } } -namespace { -// TODO(sanjoy): This is duplicated in tensorflow/core/lib/core/arena.cc. -// Extract out a common implementation to tensorflow/core/lib/math/math_util.h -uint32 GCD(uint32 x, uint32 y) { - while (y != 0) { - uint32 r = x % y; - x = y; - y = r; - } - return x; -} -} // namespace - StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function, @@ -1780,9 +1692,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( std::find(dimensions.begin(), dimensions.end(), arg->shape().layout().minor_to_major(0)) != dimensions.end(); - unsigned element_alignment = - GCD(ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), - MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); + unsigned element_alignment = tensorflow::MathUtil::GCD( + ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), + MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); if (is_reduction_over_minor_dimension) { // TODO(sanjoy): Implement vectorized reduction over the minor dimension. @@ -1983,11 +1895,16 @@ Status IrEmitter::HandleSend(HloInstruction* send) { return Unimplemented("Send is not implemented on CPU. See b/33942983."); } +Status IrEmitter::HandleSendDone(HloInstruction* send_done) { + // TODO(b/33942983): Support Send/Recv on CPU. + return Unimplemented("Send-done is not implemented on CPU. See b/33942983."); +} + Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); // The code below emits a sequential loop nest. For the parallel backend, use - // EmitParallelTargetElementLoop() which respects dynamic loop bounds. + // ParallelLoopEmitter which respects dynamic loop bounds. if (ShouldEmitParallelLoopFor(*slice)) { return DefaultAction(slice); } @@ -2148,6 +2065,11 @@ Status IrEmitter::HandleRecv(HloInstruction* recv) { return Unimplemented("Recv is not implemented on CPU. See b/33942983."); } +Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { + // TODO(b/33942983): Support Send/Recv on CPU. + return Unimplemented("Recv-done is not implemented on CPU. See b/33942983."); +} + Status IrEmitter::HandlePad(HloInstruction* pad) { // CPU backend does not properly handle negative padding but this is ok // because negative padding should be removed by the algebraic simplifier. @@ -2292,9 +2214,17 @@ Status IrEmitter::HandleCall(HloInstruction* call) { !parallel_cpu_backend_) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. - TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, - emitted_value_[call], computation, - call_ir_function)); + std::vector call_args = GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, computation->name(), + /*return_value_buffer=*/emitted_value_[call], + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument()); + + HloInstruction* root = computation->root_instruction(); + TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( + call_args, root->shape(), root->outer_dimension_partitions(), + &ir_builder_, call_ir_function, computation->name())); } else { EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, emitted_value_[call], computation->name()); @@ -2397,7 +2327,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), - compute_function_); + compute_function_->function()); ir_builder_.CreateBr(header_bb); ir_builder_.SetInsertPoint(header_bb); @@ -2414,7 +2344,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "body")), - compute_function_); + compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); @@ -2429,7 +2359,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { ir_builder_.CreateBr(header_bb); // Adds the exit block to the function and sets the insert point there. - compute_function_->getBasicBlockList().push_back(exit_bb); + compute_function_->function()->getBasicBlockList().push_back(exit_bb); ir_builder_.SetInsertPoint(exit_bb); return Status::OK(); @@ -2547,7 +2477,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, const llvm_ir::IrArray& source_array) { unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - unsigned element_alignment = GCD( + unsigned element_alignment = tensorflow::MathUtil::GCD( primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); @@ -2594,6 +2524,65 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { return DefaultAction(concatenate); } +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && + pred->shape().element_type() == PRED) + << "Predicate on a Conditional must be bool; got: " + << ShapeUtil::HumanString(pred->shape()); + + HloComputation* true_computation = conditional->true_computation(); + HloComputation* false_computation = conditional->false_computation(); + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + true_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the true " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << ShapeUtil::HumanString(true_computation->root_instruction()->shape()); + + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + false_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the false " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << ShapeUtil::HumanString(false_computation->root_instruction()->shape()); + + llvm::Function* true_function = + FindOrDie(emitted_functions_, true_computation); + llvm::Function* false_function = + FindOrDie(emitted_functions_, false_computation); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); + llvm::Value* conditional_result = GetEmittedValueFor(conditional); + + // Generating: + // if (pred) + // cond_result = true_computation(true_operand) + // else + // cond_result = false_computation(false_operand) + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + "boolean_predicate"); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, + conditional_result, IrName(conditional, "_true")); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, + conditional_result, IrName(conditional, "_false")); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2605,53 +2594,56 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { llvm::Value* root_value = GetEmittedValueFor(root); VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); - // For the parallel cpu backend, we record the total for each embedded - // computation callee with its caller kCall HLO. - HloInstruction* hlo_to_lookup = nullptr; - if (parallel_cpu_backend_ && is_top_level_computation_) { - auto* computation = root->parent(); - auto* entry_computation = computation->parent()->entry_computation(); - if (computation != entry_computation) { - for (HloInstruction* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCall && - instruction->to_apply()->root_instruction() == root) { - hlo_to_lookup = instruction; - break; + llvm::Value* prof_counter = [&]() { + // For the parallel cpu backend, we record the total for each embedded + // computation callee with its caller kCall HLO. + if (parallel_cpu_backend_ && is_top_level_computation_) { + auto* computation = root->parent(); + auto* entry_computation = computation->parent()->entry_computation(); + if (computation != entry_computation) { + for (HloInstruction* instruction : entry_computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCall && + instruction->to_apply()->root_instruction() == root) { + return GetProfileCounterFor(*instruction); + } } } } - } - if (auto* prof_counter = GetProfileCounterFor(hlo_to_lookup)) { + + // Otherwise we record the total computation cycles in a dedicated slot for + // the entry computation. + return GetProfileCounterForEntryComputation(); + }(); + + if (prof_counter) { profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); } - - ir_builder_.CreateRetVoid(); return Status::OK(); } -llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) { - string counter_name; - size_t prof_counter_idx; - if (!hlo_to_profile_idx_) { +llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction& hlo) { + auto it = hlo_to_profile_idx_.find(&hlo); + if (it == hlo_to_profile_idx_.end()) { return nullptr; } - if (hlo) { - auto it = hlo_to_profile_idx_->find(hlo); - if (it == hlo_to_profile_idx_->end()) { - return nullptr; - } - prof_counter_idx = it->second; - counter_name = IrName("prof_counter", hlo->name()); - } else { - prof_counter_idx = hlo_to_profile_idx_->size(); - counter_name = "prof_counter.computation"; - } + size_t prof_counter_idx = it->second; + string counter_name = IrName("prof_counter", hlo.name()); return ir_builder_.CreateGEP(GetProfileCountersArgument(), ir_builder_.getInt64(prof_counter_idx), AsStringRef(counter_name)); } +llvm::Value* IrEmitter::GetProfileCounterForEntryComputation() { + if (entry_computation_profile_idx_) { + return ir_builder_.CreateGEP( + GetProfileCountersArgument(), + ir_builder_.getInt64(*entry_computation_profile_idx_), + "prof_counter.computation"); + } + return nullptr; +} + void IrEmitter::ProfilingState::UpdateProfileCounter( llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, llvm::Value* cycle_end, llvm::Value* cycle_start) { @@ -2723,14 +2715,14 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); - if (hlo_to_profile_idx_ && hlo_to_profile_idx_->count(hlo)) { + if (hlo_to_profile_idx_.count(hlo)) { profiling_state_.RecordCycleStart(&ir_builder_, hlo); } return Status::OK(); } Status IrEmitter::Postprocess(HloInstruction* hlo) { - if (auto* prof_counter = GetProfileCounterFor(hlo)) { + if (auto* prof_counter = GetProfileCounterFor(*hlo)) { profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter); } return Status::OK(); @@ -2766,45 +2758,16 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, module_); } -std::vector IrEmitter::GetComputeFunctionParams() { - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (num_dynamic_loop_bounds_ > 0) { - compute_function_params.push_back(i64_ptr_type); - } - if (hlo_to_profile_idx_) { - compute_function_params.push_back(i64_ptr_type); - } - return compute_function_params; -} - -llvm::Argument* IrEmitter::GetResultArgument() { - return GetArg(compute_function_, 0); -} - -llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; - return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr; +llvm::Value* IrEmitter::GetProfileCountersArgument() { + return compute_function_->profile_counters_arg(); } llvm::Value* IrEmitter::GetTempBuffersArgument() { - return GetArg(compute_function_, 3); -} - -llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { - CHECK_GT(num_dynamic_loop_bounds_, 0); - CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_.CreateLoad(ir_builder_.CreateGEP( - loop_bounds_arg, ir_builder_.getInt64(offset), AsStringRef(name))); + return compute_function_->temp_buffers_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { - return GetArg(compute_function_, 1); + return compute_function_->exec_run_options_arg(); } llvm::Value* IrEmitter::EmitTempBufferPointer( @@ -2869,42 +2832,6 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits code to allocate an array of parameter address pointers, and store -// each address from 'parameter_addresses'. -// Returns an array of compute function call arguments (including parameter -// address buffer). -std::vector IrEmitter::GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt8PtrTy(), - ir_builder_.getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - &ir_builder_); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( - parameter_addresses[i], ir_builder_.getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder_.getInt64(i)}); - ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); - } - - const auto to_int8_ptr = [this](llvm::Value* ptr) { - return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy()); - }; - std::vector arguments{ - to_int8_ptr(return_value_buffer), - to_int8_ptr(GetExecutableRunOptionsArgument()), - parameter_addresses_buffer, GetTempBuffersArgument()}; - if (auto* profile_counters = GetProfileCountersArgument()) { - arguments.push_back(profile_counters); - } - return arguments; -} - // Emits a core function call based on the following pseudo-code. // // char** parameter_addresses_buffer = @@ -2920,8 +2847,12 @@ void IrEmitter::EmitArrayFunctionCallInto( tensorflow::gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { ir_builder_.CreateCall( - function, GetArrayFunctionCallArguments(parameter_addresses, - return_value_buffer, name)); + function, GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -2941,117 +2872,13 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -// Emits a call to a runtime fork/join function which dispatches parallel -// calls to 'parallel_function' (and joins threads before returning). -Status IrEmitter::EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function) { - HloInstruction* root = computation->root_instruction(); - - // Build ParallelForkJoin function type. - std::vector compute_function_params = GetComputeFunctionParams(); - // Number of parallel compute functions. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Array of partitions. There is an array element for each - // partition x partition_dim x 2 (for dimension start and limit). - compute_function_params.push_back( - llvm::Type::getInt64PtrTy(module_->getContext())); - // Number of partitioned most-major dimensions in 'root.shape'. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Function pointer for compute function to be dispatched in parallel. - compute_function_params.push_back( - llvm::Type::getInt8PtrTy(module_->getContext())); - - llvm::FunctionType* fork_join_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, - /*isVarArg=*/false); - - llvm::Function* fork_join_func = - llvm::cast(module_->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); - fork_join_func->setCallingConv(llvm::CallingConv::C); - fork_join_func->setDoesNotThrow(); - - // Add common compute function arguments. - const string name = computation->name(); - std::vector arguments = - GetArrayFunctionCallArguments(parameter_addresses, output_address, name); - - // Create ShapePartitionIterator to generate all partitions of 'root.shape'. - ShapePartitionIterator partition_iterator(root->shape(), - root->outer_dimension_partitions()); - const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); - // Add argument specifying the number of parallel partitions. - arguments.push_back(ir_builder_.getInt32(num_partitions)); - - // The number of partitioned most-major dimensions in 'root.shape'. - const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); - // A dimension partition consists of two elements: [start_index, limit_index). - const int32 dim_partition_size = 2; - // Calculate array partition stride. - const int32 array_partition_stride = - num_partitioned_dims * dim_partition_size; - // Calculate the total number of elements in the partition array. - const int32 partition_array_size = - dim_partition_size * num_partitioned_dims * num_partitions; - - // Store dimension partition values as llvm constants in 'partitions'. - // See comments in runtime_fork_join.cc for array layout description. - std::vector partitions(partition_array_size); - for (int32 i = 0; i < num_partitions; ++i) { - std::vector> dim_partitions = - partition_iterator.GetPartition(i); - CHECK_EQ(num_partitioned_dims, dim_partitions.size()); - const int32 partition_index = i * array_partition_stride; - for (int32 j = 0; j < num_partitioned_dims; ++j) { - const std::pair& dim_partition = dim_partitions[j]; - const int32 index = partition_index + j * dim_partition_size; - // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder_.getInt64(dim_partition.first); - partitions[index + 1] = - ir_builder_.getInt64(dim_partition.first + dim_partition.second); - } - } - - // Create global variable out of dimension partitions in 'partitions'. - llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); - llvm::Constant* partitions_array = - llvm::ConstantArray::get(partitions_array_type, partitions); - llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/partitions_array_type, - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/partitions_array, - /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); - - // Add argument specifying parallel dimension partitions. - arguments.push_back(ir_builder_.CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module_->getContext()))); - // Add argument specifying the number of partitioned most-major dimensions. - arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); - // Add argument for parallel compute function pointer. - arguments.push_back( - ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); - // Emit call to parallel fork/join. - ir_builder_.CreateCall(fork_join_func, arguments); - - return Status::OK(); -} - Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { llvm::Value* addr; const Shape& target_shape = op->shape(); if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. - llvm::Argument* retval = GetResultArgument(); + llvm::Argument* retval = compute_function_->result_arg(); if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); @@ -3112,8 +2939,13 @@ Status IrEmitter::EmitTargetElementLoop( } else { if (ShouldEmitParallelLoopFor(*target_op)) { - TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( - target_shape, element_generator, IrName(target_op), &target_array)); + // Emit code to read dynamic loop bounds from compute function argument. + std::vector> dynamic_loop_bounds = + compute_function_->GetDynamicLoopBounds(); + // Emit parallel loop with dynamic loop bounds for most-major dimensions. + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, + &dynamic_loop_bounds, &ir_builder_) + .EmitLoop(IrName(target_op))); } else { TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) @@ -3123,60 +2955,6 @@ Status IrEmitter::EmitTargetElementLoop( return Status::OK(); } -Status IrEmitter::EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array) { - CHECK(!ShapeUtil::IsTuple(target_shape)); - CHECK(!ShapeUtil::IsScalar(target_shape)); - - // Emit code to read dynamic loop bounds from function argument 4. - std::vector dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); - for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { - dynamic_loop_bounds[i] = GetDynamicLoopBound(i); - } - - llvm_ir::ForLoopNest loop_nest(loop_name, &ir_builder_); - const int64 num_dims = target_shape.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); - - // Add loops from outer-most to inner-most dimensions. - for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - const int64 dimension = target_shape.layout().minor_to_major(i); - const int bounds_index = num_dims - 1 - i; - if (bounds_index < num_dynamic_loop_bounds_) { - // Emit dynamic loop bounds for this dimension. Dynamic loop bounds - // are read from ir function dynamic loop bounds argument. - llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; - llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; - - std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); - array_index[dimension] = loop->GetIndVarValue(); - } else { - // Emit static loop bounds for this dimension. - std::unique_ptr loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/target_shape.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); - array_index[dimension] = loop->GetIndVarValue(); - } - } - // Point IR builder at inner loop BB. - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Emit loop body. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - element_generator(array_index)); - target_array->EmitWriteArrayElement(array_index, target_element, - &ir_builder_); - // Point IR builder at outer loop exit BB. - SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); - - return Status::OK(); -} - Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 5d061e11e3c9e07bdcfdc749711e4369ec2bea2a..9bc2d9739757168562b8dc7b482eff203f303766 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -30,6 +31,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -105,15 +107,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { // llvm_module: the LLVM module to emit IR into. // hlo_to_profile_idx: the mapping from HLO to its index in the profiling // array. + // entry_computation_profile_idx: the index in the profiling array + // for the entry computation. // external_constant_pool: if non-null, points to an ExternalConstantPool // instance into which the Ir emitter can spill // constants. - IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, - const std::unordered_map* - hlo_to_profile_idx, - llvm::TargetMachine* target_machine, - ExternalConstantPool* external_constant_pool); + IrEmitter( + const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, + std::unordered_map hlo_to_profile_idx, + tensorflow::gtl::optional entry_computation_profile_idx, + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -171,11 +176,13 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleReduceWindow(HloInstruction* reduce_window) override; Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override; Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) override; Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; Status HandlePad(HloInstruction* pad) override; Status HandleTuple(HloInstruction* tuple) override; Status HandleMap(HloInstruction* map) override; @@ -184,6 +191,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConditional(HloInstruction* conditional) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -195,7 +203,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to generate a GEP into the profile counter parameter // which would correspond to the index for a given HLO. - llvm::Value* GetProfileCounterFor(const HloInstruction* hlo); + llvm::Value* GetProfileCounterFor(const HloInstruction& hlo); + + // Convenience function to generate a GEP into the profile counter parameter + // corresponding to the index for the entry computation. Returns nullptr if + // profiling the entry computation is disabled. + llvm::Value* GetProfileCounterForEntryComputation(); // Gets the IR Value emitted previously for the given hlo. // @@ -223,16 +236,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); - // Returns an array of compute function parameter types. - std::vector GetComputeFunctionParams(); - - // Get the llvm::Value* that represents the "retval" argument of the - // computation function being emitted by this emitter. - llvm::Argument* GetResultArgument(); - // Get the llvm::Value* that represents the "prof_counters" argument of the // computation function being emitted by this emitter. - llvm::Argument* GetProfileCountersArgument(); + llvm::Value* GetProfileCountersArgument(); // Get the xla::ExecutableRunOptions that represents the "run_options" // argument of the computation function being emitted by this emitter. @@ -242,11 +248,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emit ir to read and return the ir value for the dynamic loop bound at - // 'offset' from the "dynamic_loop_bounds" argument of the computation - // function being emitted by this emitter. - llvm::Value* GetDynamicLoopBound(const int64 offset); - // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. @@ -300,18 +301,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name); - // Returns an array of compute function call arguments. - std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Emits a call to a runtime fork/join function which dispatches parallel - // calls to 'parallel_function' (and joins threads before returning). - Status EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function); - // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -336,15 +325,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, tensorflow::StringPiece desc, const llvm_ir::ElementGenerator& element_generator); - // Emit IR to perform a computation for every element in a partition/slice of - // 'target_shape'. The loop bounds for the outer-dimension partitions are - // passed into the compute function as a runtime argument (accessible from - // GetDynamicLoopBound). - Status EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array); - // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -466,12 +446,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory - // management rules, their memory is owned by the module. - llvm::Function* compute_function_; + // management rules, their memory is owned by the module (Note that IrFunction + // creates the encapsulated llvm::Function s.t. it is added to the llvm + // module's function list). + std::unique_ptr compute_function_; llvm::IRBuilder<> ir_builder_; // Maps HLOs to their index into the profile counter array. - const std::unordered_map* hlo_to_profile_idx_; + std::unordered_map hlo_to_profile_idx_; + const tensorflow::gtl::optional entry_computation_profile_idx_; // Maps HLOs to Values emitted for them. std::unordered_map emitted_value_; @@ -479,7 +462,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; // The number of root instruction outer dimensions used in parallel loop - // emission (EmitParallelTargetElementLoop). + // emission (ParallelLoopEmitter). int64 num_dynamic_loop_bounds_ = 0; // Returns whether the given instruction should be emitted as a parallel loop. @@ -499,7 +482,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { use_rdtscp_(false), prof_counters_(nullptr) {} ProfilingState(bool is_top_level_computation, bool use_rdtscp, - llvm::Argument* prof_counters) + llvm::Value* prof_counters) : is_top_level_computation_(is_top_level_computation), use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} @@ -532,7 +515,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { bool use_rdtscp_; // The argument which corresponds to the profile counter buffer. - llvm::Argument* prof_counters_; + llvm::Value* prof_counters_; // The first read cycle counter in the program. llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca8c290dd1c4959e42026c3917d37f8fc95a1011 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -0,0 +1,333 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +namespace { +using llvm_ir::AsStringRef; +} // namespace + +namespace cpu { + +static std::vector GetComputeFunctionParams( + llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = + llvm::Type::getInt64PtrTy(llvm_module->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds > 0) { + compute_function_params.push_back(i64_ptr_type); + } + compute_function_params.push_back(i64_ptr_type); + return compute_function_params; +} + +IrFunction::IrFunction(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, + int64 num_dynamic_loop_bounds) + : ir_builder_(ir_builder), + llvm_module_(llvm_module), + caller_insert_point_guard_(*ir_builder), + num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { + Initialize(function_name, linkage, optimize_for_size_requested, + enable_fast_math); +} + +IrFunction::~IrFunction() { + // Emit function return value. + ir_builder_->CreateRetVoid(); +} + +DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { + DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_); + for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0); + dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1); + } + return dynamic_loop_bounds; +} + +void IrFunction::Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math) { + // The function signature is: + // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // i64* dynamic_loop_bounds, i64* prof_counters) + // + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: address of an array with pointers to temporary buffers. + // + // Therefore, the generated function's signature (FunctionType) is statically + // determined - parameter unpacking is done in code generated into the + // function, rather than by a prologue dictated by the platform ABI. + // + // /--------------\ + // retval ----------> | return value | + // \--------------/ + // + // /-------------------------------\ + // run_options -----> | xla::ExecutableRunOptions | + // \-------------------------------/ + // + // /---------------------------------------------\ + // params --------> | param 0 | param 1 | ..... | param N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | param 0 | | param 1 | | param N-1 | + // \---------/ \---------/ \-----------/ + // + // /---------------------------------------------\ + // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | temp 0 | | temp 1 | | temp N-1 | + // \---------/ \---------/ \-----------/ + // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // + // /---------------------------------------------\ + // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | + // \---------------------------------------------/ + + // Even though the type of params and temps is void** in the host's view, in + // LLVM IR this is represented by i8*, similarly to void*. It's up to the code + // to use GEPs to unravel the indirection layers. + llvm::FunctionType* function_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), + /*Params=*/ + GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_), + /*isVarArg=*/false); + + // Functions with local linkage get an inlining bonus. Because we know + // a-priori that embedded functions (non-entry functions) will not have its + // name resolved, give it local linkage. + function_ = + llvm_ir::CreateFunction(function_type, linkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size_requested, + function_name, llvm_module_); + + // Set meaningful names for the function's arguments: useful for debugging. + llvm::Function::arg_iterator arg_iter = function_->arg_begin(); + arg_iter->setName("retval"); + result_arg_ = &*arg_iter; + (++arg_iter)->setName("run_options"); + exec_run_options_arg_ = &*arg_iter; + (++arg_iter)->setName("params"); + parameters_arg_ = &*arg_iter; + (++arg_iter)->setName("temps"); + temp_buffers_arg_ = &*arg_iter; + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + dynamic_loop_bounds_arg_ = &*arg_iter; + } + (++arg_iter)->setName("prof_counters"); + profile_counters_arg_ = &*arg_iter; + + // We know a-priori that the function arguments are guaranteed to point to + // disjoint objects. + llvm::Argument* retval = result_arg(); + for (llvm::Argument& argument : function_->args()) { + // However, the return buffer aliases the temporaries and thus cannot be + // marked noalias. + if (&argument == retval) { + continue; + } + function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); + } + + ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + /*Context=*/llvm_module_->getContext(), + /*Name=*/"entry", + /*Parent=*/function_)); +} + +llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_->CreateLoad( + ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + ir_builder_->getInt64(offset), AsStringRef(name))); +} + +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + llvm::Value* parameter_addresses_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder->getInt8PtrTy(), + ir_builder->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), + ir_builder); + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( + parameter_addresses[i], ir_builder->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); + llvm::Value* slot_in_param_adresses = ir_builder->CreateInBoundsGEP( + parameter_addresses_buffer, {ir_builder->getInt64(i)}); + ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_adresses); + } + + const auto to_int8_ptr = [=](llvm::Value* ptr) { + return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + }; + std::vector arguments{ + to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), + parameter_addresses_buffer, temp_buffers_arg}; + if (profile_counters_arg != nullptr) { + arguments.push_back(profile_counters_arg); + } + return arguments; +} + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + + // Build ParallelForkJoin function type. + std::vector compute_function_params = + GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module->getContext())); + // Number of partitioned most-major dimensions in 'shape'. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast(module->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + std::vector fork_join_arguments(arguments); + + // Create ShapePartitionIterator to generate all partitions of 'shape'. + ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'shape'. + const int32 num_partitioned_dims = dimension_partition_counts.size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. + const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder->getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + fork_join_arguments.push_back(ir_builder->CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + fork_join_arguments.push_back( + ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder->CreateCall(fork_join_func, fork_join_arguments); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h new file mode 100644 index 0000000000000000000000000000000000000000..1fd2da4dce23982ed030f3aa8ec604182d0ebab8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ + +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace cpu { + +// IrFunction creates and encapsulates an llvm::Function, exposing methods to +// emitters for function and function argument access. +// The llvm::Function is created with the standard function signature +// used in the XLA CPU backend (see ir_function.cc for argument details). +// In addtion IrFunction saves the callers IR insert point during contruction, +// and restores it after desctruction. +// +// Example usage: +// +// // Create and initialize new IrFunction. +// std::unique_ptr compute_function(new IrFunction(...)); +// // Emit IR for function body using IrFunction helper methods. +// ... +// // Store reference to llvm::Function for future invocation. +// ir_functions.push_back(compute_function.function()); +// // Delete IrFunction (finalizes IR function and restores caller insertion +// // point). +// compute_function.reset(); +// + +class IrFunction { + public: + IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); + ~IrFunction(); + + // Emit ir to read and return the set of ir values representing the dynamic + // loop bounds argument of this function. + // Each element in returned vector is a pair of ir values representing + // the loop bounds for a specific dimension, where the first element of the + // pair is the dimension start index, and the second element of the pair + // is the dimension limit. + // EX: [dimension_i_index_start_ir_value, dimension_i_index_limit_ir_value] + // + DynamicLoopBounds GetDynamicLoopBounds(); + + // Returns the encapculated llvm::Function. + llvm::Function* function() { return function_; } + + // Get the llvm::Value* that represents this functions "retval" argument. + llvm::Argument* result_arg() { return result_arg_; } + + // Get the xla::ExecutableRunOptions that represents this functions + // "run_options" argument. + llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; } + + // Get the llvm::Value* that represents this functions parameters argument. + llvm::Value* parameters_arg() { return parameters_arg_; } + + // Get the llvm::Value* that represents this functions "temps" argument. + llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + + // Get the llvm::Value* that represents this functions "prof_counters" + // argument. + llvm::Value* profile_counters_arg() { return profile_counters_arg_; } + + private: + // Initialize an llvm::Function with standard signature based on arguments. + void Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + bool optimize_for_size_requested, bool enable_fast_math); + + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of this function. + llvm::Value* GetDynamicLoopBound(int64 offset); + + llvm::IRBuilder<>* ir_builder_; + llvm::Module* llvm_module_; + llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; + + int64 num_dynamic_loop_bounds_ = 0; + // Encapsulated llvm::Function. + llvm::Function* function_; + // Function argument IR values. + llvm::Argument* result_arg_; + llvm::Value* exec_run_options_arg_; + llvm::Value* parameters_arg_; + llvm::Value* temp_buffers_arg_; + llvm::Value* dynamic_loop_bounds_arg_ = nullptr; + llvm::Value* profile_counters_arg_; +}; + +// Returns an array of compute function call argument ir values. +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name); + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index b49047283119fb2f10b9f68eaa37a7bdc27f63a6..81c29e4726c7be53b433be896f558f502e43c885 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -52,7 +52,7 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, llvm::IRBuilder<> ir_builder(vector_tanh_body); llvm::FastMathFlags fast_math_flags; - fast_math_flags.setUnsafeAlgebra(); + fast_math_flags.setFast(); ir_builder.setFastMathFlags(fast_math_flags); llvm::Value* input = &*vector_tanh_function->arg_begin(); diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..e624e5cc7ebdbb79a8a3b3c73633ec697a71d172 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.cc @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" + +namespace xla { +namespace cpu { +namespace orc_jit_memory_mapper { + +static tensorflow::mutex mapper_instance_mutex(tensorflow::LINKER_INITIALIZED); +static llvm::SectionMemoryManager::MemoryMapper* mapper_instance + GUARDED_BY(mapper_instance_mutex) = nullptr; + +llvm::SectionMemoryManager::MemoryMapper* GetInstance() { + tensorflow::mutex_lock lock(mapper_instance_mutex); + return mapper_instance; +} + +Registrar::Registrar( + std::unique_ptr mapper) { + tensorflow::mutex_lock lock(mapper_instance_mutex); + mapper_instance = mapper.release(); +} +} // namespace orc_jit_memory_mapper +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h new file mode 100644 index 0000000000000000000000000000000000000000..2d29550fd5bd659770cc6300e56b57bf1763e671 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ + +#include + +#include "llvm/ExecutionEngine/SectionMemoryManager.h" + +namespace xla { +namespace cpu { + +namespace orc_jit_memory_mapper { +// Returns the registered memory mapper if there is one. Returns nullptr if no +// memory mapper is registered. +llvm::SectionMemoryManager::MemoryMapper* GetInstance(); + +class Registrar { + public: + // Registers the `mapper` as a memory mapper. This is a no-op if `mapper` is + // null. Precondition: no other memory mapper has been registered yet. + explicit Registrar( + std::unique_ptr mapper); +}; +} // namespace orc_jit_memory_mapper + +#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(mapper_instance, ctr) \ + static ::xla::cpu::orc_jit_memory_mapper::Registrar \ + XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr)(mapper_instance) + +// __COUNTER__ must go through another macro to be properly expanded +#define XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER_NAME(ctr) \ + __orc_jit_memory_mapper_registrar_##ctr + +// Registers the std::unique_ptr +// returned by the `factory` expression. `factory` is allowed to evaluate to +// a null unique_ptr in which case this macro does nothing. +#define XLA_REGISTER_ORC_JIT_MEMORY_MAPPER(factory) \ + XLA_INTERNAL_REGISTER_ORC_JIT_MEMORY_MAPPER(factory, __COUNTER__) +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index adedc1c37fdc8fb3c3e017f0773ef3fc52ebdec6..0077e344e2bd34aa598ee076220fee678f31b4ad 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -58,20 +58,21 @@ ParallelCpuExecutable::ParallelCpuExecutable( std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, - std::unique_ptr> function_names, - std::unordered_map hlo_to_profile_idx, + std::unique_ptr> function_names, std::unordered_map> - aligned_constants) - : Executable(std::move(hlo_module)), + aligned_constants, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), jit_(std::move(jit)), assignment_(std::move(assignment)), function_names_(std::move(function_names)), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), aligned_constants_(std::move(aligned_constants)) {} // Type of the computation function we expect in the JIT. using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - int64*, uint64*); + int64*, int64*); // Given a pointer to an output buffer (following the CPU JIT calling // conventions), mark addresses that are "live". The initial pointer itself is @@ -102,11 +103,11 @@ namespace { // in 'pending' on 'thread_pool' (storing resulting data in 'results'). class Executor { public: - Executor(const std::map& functions, + Executor(const HloInstructionMap& functions, const ServiceExecutableRunOptions* run_options, std::list* pending, - std::map* results, void** temps_array, - uint64* profile_counters_array, const BufferAssignment* assignment) + HloInstructionMap* results, void** temps_array, + int64* profile_counters_array, const BufferAssignment* assignment) : functions_(functions), run_options_(run_options), pending_(pending), @@ -142,12 +143,12 @@ class Executor { const void** GetOperandBuffers(HloInstruction* instruction); // Arguments passed into Executor. - const std::map& functions_; + const HloInstructionMap& functions_; const ServiceExecutableRunOptions* run_options_; std::list* pending_; - std::map* results_; + HloInstructionMap* results_; void** temps_array_; - uint64* profile_counters_array_; + int64* profile_counters_array_; tensorflow::thread::ThreadPool* thread_pool_; const BufferAssignment* assignment_; @@ -389,9 +390,11 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { // Allocate profiling counters for each hlo instruction that we would like to - // profile. Allocate an additional profile counter for the entire - // computation. - std::vector profile_counters(hlo_to_profile_idx_.size() + 1); + // profile. + std::vector* profile_counters = nullptr; + if (hlo_execution_profile) { + profile_counters = hlo_execution_profile->mutable_profile_counters(); + } std::vector buffer_pointers; buffer_pointers.reserve(buffers.size()); @@ -400,7 +403,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } // Resolve functions for all the HLO instructions ahead of time. - std::map functions; + HloInstructionMap functions; for (auto& entry : *function_names_) { tensorflow::mutex_lock lock(jit_mutex_); HloInstruction* instruction = entry.first; @@ -412,7 +415,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } // Map containing pointers to result buffers for each instruction. - std::map results; + HloInstructionMap results; uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -441,9 +444,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( // For example, if we expect a library conv/matmul call to run at max // concurrency, we should not dispatch runnable instructions until the // library call is finished (to avoid expensive cache invalidation). - Executor executor(functions, run_options, &pending, &results, - buffer_pointers.data(), profile_counters.data(), - assignment_.get()); + Executor executor( + functions, run_options, &pending, &results, buffer_pointers.data(), + profile_counters ? profile_counters->data() : nullptr, assignment_.get()); TF_RETURN_IF_ERROR(executor.Run()); @@ -453,18 +456,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::mutex_lock lock(mutex_); double nanoseconds = (end_micros - start_micros) * 1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - // The last profile counter is used for the computation as a whole. - execution_profile_.set_compute_cycle_count(profile_counters.back()); - } - if (hlo_execution_profile != nullptr) { - hlo_execution_profile->set_total_cycles_executed(entry_computation, - profile_counters.back()); - - for (auto hlo_prof_idx : hlo_to_profile_idx_) { - const HloInstruction* hlo = hlo_prof_idx.first; - uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->AddProfileResult(hlo, cycles_taken); - } } return Status::OK(); @@ -618,10 +609,5 @@ const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr ParallelCpuExecutable::CreateCostAnalysis() - const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index a75552b7d1eeda2f04e95fb8abc3a597f423024a..d65e3f42f3cb34eff005f34b51b81fd5c42974a3 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -51,11 +51,12 @@ class ParallelCpuExecutable : public Executable { std::unique_ptr jit, std::unique_ptr assignment, std::unique_ptr hlo_module, - std::unique_ptr> function_names, - std::unordered_map hlo_to_profile_idx, + std::unique_ptr> function_names, std::unordered_map> - aligned_constants); + aligned_constants, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); ~ParallelCpuExecutable() override {} StatusOr ExecuteOnStream( @@ -95,8 +96,6 @@ class ParallelCpuExecutable : public Executable { "Equality test on CPU parallel executable is not implemented."); } - std::unique_ptr CreateCostAnalysis() const override; - private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer @@ -141,11 +140,7 @@ class ParallelCpuExecutable : public Executable { string ir_module_string_; // Map containing the JITted function names for each HLO instruction. - const std::unique_ptr> - function_names_; - - // Maps HLOs to their index into the profile counter array. - const std::unordered_map hlo_to_profile_idx_; + const std::unique_ptr> function_names_; // Map from HLO Constant instructions to a pointer to their literal data. // The data stored in the protocol buffer might be insufficiently aligned, diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3c3c1e5efc91af6b924a3712689f3d7ccf5d6f6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace cpu { + +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_array, ir_builder), + dynamic_loop_bounds_(dynamic_loop_bounds) {} + +llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) { + CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsScalar(shape_)); + + llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); + const int64 num_dims = shape_.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { + const int64 dimension = shape_.layout().minor_to_major(i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < dynamic_loop_bounds_->size()) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = (*dynamic_loop_bounds_)[bounds_index].first; + llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; + + std::unique_ptr loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. + std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), + ir_builder_); + + // Set exit_bb_ to the exit block of the loop nest. + exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); + CHECK(exit_bb_ != nullptr); + + return array_index; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..9335d2818e99eb3588537d80dabddda08c1c020e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" + +namespace xla { +namespace cpu { + +// ParallelLoopEmitter emits a loop nest for the target array shape. +// The outer loop bounds of the loop nest are passed as ir values at runtime +// (specified in 'dynamic_loop_bounds'), and the inner loop bounds are static. +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// Code emitted by ParallelLoopEmitter will be called in a multi-threaded +// context where each thread will be assigned a different set of outer dimension +// partitions, and where all threads will collectively iterate over the +// entire target array shape. +// +// Outer dimension partitions can be generated using the ShapePartitionAssigner +// and ShapePartitionIterator utility classes from shape_partition.cc. +// +class ParallelLoopEmitter : public llvm_ir::LoopEmitter { + public: + // Constructs a ParallelLoopEmitter which uses 'target_element_generator' to + // generate elements, 'dynamic_loop_bounds' to set the loop bounds of the + // most-major dimensions, and 'target_array.' shape to set the static loop + // bounds for the most-minor dimensions. + ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, + llvm::IRBuilder<>* ir_builder); + + ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; + ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; + ~ParallelLoopEmitter() override = default; + + llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) override; + + private: + const DynamicLoopBounds* dynamic_loop_bounds_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index c2213c8f2ef592c537daf9abe2ffa10b83a8fa4c..4b44ac8941e222d5954121bbb9654062e41f55d6 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -101,11 +102,9 @@ class DefaultCostModel : public ParallelCostModel { const std::unique_ptr cost_analysis_; }; - ParallelTaskAssignment::ParallelTaskAssignment( const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module) { + const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. auto cost_analysis = MakeUnique(shape_size); @@ -153,7 +152,6 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( StatusOr ParallelTaskAssigner::Run(HloModule* module) { XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); XLA_VLOG_LINES(3, module->ToString()); - // Compute target parallel task counts for all instructions in 'module'. HloToParallelTasks hlo_to_parallel_tasks; ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks); @@ -230,6 +228,9 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( void ParallelTaskAssigner::ComputeTargetParallelTasks( HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { + ParallelTaskAssignment parallel_task_assignment(max_parallelism_, + shape_size_function_, module); + // Compute parallel task counts for all instructions in 'module'. for (auto* computation : module->computations()) { if (computation->IsFusionComputation()) { @@ -238,7 +239,7 @@ void ParallelTaskAssigner::ComputeTargetParallelTasks( for (auto* instruction : computation->instructions()) { // Query ParallelTaskAssignment for target parallel task count. const int64 target_parallel_task_count = - parallel_task_assignment_.GetTargetParallelTaskCount(instruction); + parallel_task_assignment.GetTargetParallelTaskCount(instruction); if (target_parallel_task_count > 1) { hlo_to_parallel_tasks->insert( {instruction, target_parallel_task_count}); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index e036da5784f6151eb3b01107ec7f3ab820071a60..5801ec8d270cdaed7f2f65c24987a9ea643edb02 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -37,10 +37,9 @@ class ParallelTaskAssignment { // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. // 'module': the containing HloModule. - ParallelTaskAssignment( - const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module); + ParallelTaskAssignment(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module); ~ParallelTaskAssignment() {} // Computes and returns the target parallel task count for 'instruction'. @@ -63,11 +62,9 @@ class ParallelTaskAssigner : public HloPassInterface { // 'max_parallelism': the maximum parallel task count per instruction. // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. - // 'module': the containing HloModule. ParallelTaskAssigner(const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module) - : parallel_task_assignment_(max_parallelism, shape_size, module) {} + const HloCostAnalysis::ShapeSizeFunction& shape_size) + : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {} ~ParallelTaskAssigner() override {} tensorflow::StringPiece name() const override { @@ -95,7 +92,8 @@ class ParallelTaskAssigner : public HloPassInterface { void ComputeTargetParallelTasks(HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks); - ParallelTaskAssignment parallel_task_assignment_; + int64 max_parallelism_; + HloCostAnalysis::ShapeSizeFunction shape_size_function_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index fdf02e5b422f75e256feec77470bb0d079e8ef1f..cda2783307925b77ac6d8cfe679c5b325db2befc 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" +#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" @@ -125,8 +126,10 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - object_layer_( - [] { return std::make_shared(); }), + object_layer_([] { + return std::make_shared( + orc_jit_memory_mapper::GetInstance()); + }), compile_layer_( object_layer_, CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, @@ -210,71 +213,75 @@ bool RegisterKnownJITSymbols() { #undef REGISTER_CPU_RUNTIME_SYMBOL -#define REGISTER_LIBM_SYMBOL(name) \ - do { \ - /* Register both the F32 and F64 variants of the libm symbol. */ \ - registry->Register(#name "f", reinterpret_cast(name##f)); \ - registry->Register(#name, reinterpret_cast(name)); \ +// Register both the f32 (float) and f64 (double) versions of a libm symbol. +// Unfortunately the double versions are overloaded on some systems, e.g. +// Mac so we need an explicit cast. This requires passing the function signature +// for that case. +#define REGISTER_LIBM_SYMBOL(name, double_sig) \ + do { \ + registry->Register(#name "f", reinterpret_cast(name##f)); \ + registry->Register( \ + #name, reinterpret_cast(static_cast(name))); \ } while (false) - REGISTER_LIBM_SYMBOL(acos); - REGISTER_LIBM_SYMBOL(acosh); - REGISTER_LIBM_SYMBOL(asin); - REGISTER_LIBM_SYMBOL(asinh); - REGISTER_LIBM_SYMBOL(atan); - REGISTER_LIBM_SYMBOL(atan2); - REGISTER_LIBM_SYMBOL(atanh); - REGISTER_LIBM_SYMBOL(cbrt); - REGISTER_LIBM_SYMBOL(ceil); - REGISTER_LIBM_SYMBOL(copysign); - REGISTER_LIBM_SYMBOL(cos); - REGISTER_LIBM_SYMBOL(cosh); - REGISTER_LIBM_SYMBOL(erf); - REGISTER_LIBM_SYMBOL(erfc); - REGISTER_LIBM_SYMBOL(exp); - REGISTER_LIBM_SYMBOL(exp2); - REGISTER_LIBM_SYMBOL(expm1); - REGISTER_LIBM_SYMBOL(fabs); - REGISTER_LIBM_SYMBOL(fdim); - REGISTER_LIBM_SYMBOL(floor); - REGISTER_LIBM_SYMBOL(fma); - REGISTER_LIBM_SYMBOL(fmax); - REGISTER_LIBM_SYMBOL(fmin); - REGISTER_LIBM_SYMBOL(fmod); - REGISTER_LIBM_SYMBOL(frexp); - REGISTER_LIBM_SYMBOL(hypot); - REGISTER_LIBM_SYMBOL(ilogb); - REGISTER_LIBM_SYMBOL(ldexp); - REGISTER_LIBM_SYMBOL(lgamma); - REGISTER_LIBM_SYMBOL(llrint); - REGISTER_LIBM_SYMBOL(llround); - REGISTER_LIBM_SYMBOL(log); - REGISTER_LIBM_SYMBOL(log10); - REGISTER_LIBM_SYMBOL(log1p); - REGISTER_LIBM_SYMBOL(log2); - REGISTER_LIBM_SYMBOL(logb); - REGISTER_LIBM_SYMBOL(lrint); - REGISTER_LIBM_SYMBOL(lround); - REGISTER_LIBM_SYMBOL(modf); - REGISTER_LIBM_SYMBOL(nan); - REGISTER_LIBM_SYMBOL(nearbyint); - REGISTER_LIBM_SYMBOL(nextafter); - REGISTER_LIBM_SYMBOL(nexttoward); - REGISTER_LIBM_SYMBOL(pow); - REGISTER_LIBM_SYMBOL(remainder); - REGISTER_LIBM_SYMBOL(remquo); - REGISTER_LIBM_SYMBOL(rint); - REGISTER_LIBM_SYMBOL(round); - REGISTER_LIBM_SYMBOL(scalbln); - REGISTER_LIBM_SYMBOL(scalbn); - REGISTER_LIBM_SYMBOL(sin); - REGISTER_LIBM_SYMBOL(sincos); - REGISTER_LIBM_SYMBOL(sinh); - REGISTER_LIBM_SYMBOL(sqrt); - REGISTER_LIBM_SYMBOL(tan); - REGISTER_LIBM_SYMBOL(tanh); - REGISTER_LIBM_SYMBOL(tgamma); - REGISTER_LIBM_SYMBOL(trunc); + REGISTER_LIBM_SYMBOL(acos, double (*)(double)); + REGISTER_LIBM_SYMBOL(acosh, double (*)(double)); + REGISTER_LIBM_SYMBOL(asin, double (*)(double)); + REGISTER_LIBM_SYMBOL(asinh, double (*)(double)); + REGISTER_LIBM_SYMBOL(atan, double (*)(double)); + REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(atanh, double (*)(double)); + REGISTER_LIBM_SYMBOL(cbrt, double (*)(double)); + REGISTER_LIBM_SYMBOL(ceil, double (*)(double)); + REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(cos, double (*)(double)); + REGISTER_LIBM_SYMBOL(cosh, double (*)(double)); + REGISTER_LIBM_SYMBOL(erf, double (*)(double)); + REGISTER_LIBM_SYMBOL(erfc, double (*)(double)); + REGISTER_LIBM_SYMBOL(exp, double (*)(double)); + REGISTER_LIBM_SYMBOL(exp2, double (*)(double)); + REGISTER_LIBM_SYMBOL(expm1, double (*)(double)); + REGISTER_LIBM_SYMBOL(fabs, double (*)(double)); + REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(floor, double (*)(double)); + REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double)); + REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*)); + REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(ilogb, int (*)(double)); + REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int)); + REGISTER_LIBM_SYMBOL(lgamma, double (*)(double)); + REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); + REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); + REGISTER_LIBM_SYMBOL(log, double (*)(double)); + REGISTER_LIBM_SYMBOL(log10, double (*)(double)); + REGISTER_LIBM_SYMBOL(log1p, double (*)(double)); + REGISTER_LIBM_SYMBOL(log2, double (*)(double)); + REGISTER_LIBM_SYMBOL(logb, double (*)(double)); + REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); + REGISTER_LIBM_SYMBOL(lround, long (*)(double)); + REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*)); + REGISTER_LIBM_SYMBOL(nan, double (*)(const char*)); + REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double)); + REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double)); + REGISTER_LIBM_SYMBOL(pow, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double)); + REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*)); + REGISTER_LIBM_SYMBOL(rint, double (*)(double)); + REGISTER_LIBM_SYMBOL(round, double (*)(double)); + REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long)); + REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); + REGISTER_LIBM_SYMBOL(sin, double (*)(double)); + REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); + REGISTER_LIBM_SYMBOL(sinh, double (*)(double)); + REGISTER_LIBM_SYMBOL(sqrt, double (*)(double)); + REGISTER_LIBM_SYMBOL(tan, double (*)(double)); + REGISTER_LIBM_SYMBOL(tanh, double (*)(double)); + REGISTER_LIBM_SYMBOL(tgamma, double (*)(double)); + REGISTER_LIBM_SYMBOL(trunc, double (*)(double)); #undef REGISTER_LIBM_SYMBOL diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 6efd0bcee58d19b355b6c2afa6d9497f75ef4b3c..2172ae0a29626660e8abd29a789e0baa3831519d 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -24,37 +24,55 @@ limitations under the License. namespace xla { -Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo) { +template +Status DfsHloVisitorBase::HandleElementwiseUnary( + HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", HloOpcodeString(hlo->opcode()).c_str()); } -Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo) { +template +Status DfsHloVisitorBase::HandleElementwiseBinary( + HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", HloOpcodeString(hlo->opcode()).c_str()); } -DfsHloVisitor::VisitState DfsHloVisitor::GetVisitState( +template +typename DfsHloVisitorBase::VisitState +DfsHloVisitorBase::GetVisitState( const HloInstruction& instruction) { return GetVisitState(instruction.unique_id()); } -void DfsHloVisitor::SetVisiting(const HloInstruction& instruction) { +template +void DfsHloVisitorBase::SetVisiting( + const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visiting: "; DCHECK(NotVisited(instruction)); visit_state_.SetState(instruction.unique_id(), VisitState::kVisiting); } -void DfsHloVisitor::SetVisited(const HloInstruction& instruction) { +template +void DfsHloVisitorBase::SetVisited( + const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visited: "; DCHECK(NotVisited(instruction) || IsVisiting(instruction)); visit_state_.SetState(instruction.unique_id(), VisitState::kVisited); } -Status DfsHloVisitor::Preprocess(HloInstruction* hlo) { return Status::OK(); } +template +Status DfsHloVisitorBase::Preprocess(HloInstructionPtr) { + return Status::OK(); +} -Status DfsHloVisitor::Postprocess(HloInstruction* visited) { +template +Status DfsHloVisitorBase::Postprocess(HloInstructionPtr) { return Status::OK(); } +// Explicit instantiations. +template class DfsHloVisitorBase; +template class DfsHloVisitorBase; + } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 237cd8c31de1ba1aa97739c579d6d92264ddc61b..91086fd4a5f68211ef56c2417bb0ef4a38de2cff 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ +#include #include #include "tensorflow/compiler/xla/literal_util.h" @@ -52,170 +53,183 @@ class HloInstruction; // "unimplemented" error status. // // Note: this may change to an iterator in the future for flexibility purposes. -class DfsHloVisitor { +// +// Users should not use this class directly, but use the type-aliases +// DfsHloVisitor/ConstDfsHloVisitor instead. +template +class DfsHloVisitorBase { + static_assert( + std::is_same::value || + std::is_same::value, + "Template argument expected to be HloInstruction* or const " + "HloInstruction*"); + public: - DfsHloVisitor() {} - virtual ~DfsHloVisitor() {} + DfsHloVisitorBase() {} + virtual ~DfsHloVisitorBase() {} // These routines are self-descriptive, see class comment for usage // information. - virtual Status HandleElementwiseUnary(HloInstruction* hlo); - virtual Status HandleElementwiseBinary(HloInstruction* hlo); - virtual Status HandleClamp(HloInstruction* clamp) = 0; - virtual Status HandleSelect(HloInstruction* select) = 0; - virtual Status HandleMaximum(HloInstruction* maximum) { - return HandleElementwiseBinary(maximum); + virtual Status HandleElementwiseUnary(HloInstructionPtr hlo); + virtual Status HandleElementwiseBinary(HloInstructionPtr hlo); + + virtual Status HandleClamp(HloInstructionPtr hlo) = 0; + virtual Status HandleSelect(HloInstructionPtr hlo) = 0; + virtual Status HandleMaximum(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleMinimum(HloInstruction* minimum) { - return HandleElementwiseBinary(minimum); + virtual Status HandleMinimum(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleConcatenate(HloInstruction* concatenate) = 0; - virtual Status HandleConvert(HloInstruction* convert) { - return HandleElementwiseUnary(convert); + virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0; + virtual Status HandleConvert(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleCopy(HloInstruction* copy) { - return HandleElementwiseUnary(copy); + virtual Status HandleBitcastConvert(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleComplex(HloInstruction* complex) { - return HandleElementwiseBinary(complex); + virtual Status HandleCopy(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleMultiply(HloInstruction* multiply) { - return HandleElementwiseBinary(multiply); + virtual Status HandleComplex(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleDot(HloInstruction* dot) = 0; - virtual Status HandlePower(HloInstruction* power) { - return HandleElementwiseBinary(power); + virtual Status HandleMultiply(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleConvolution(HloInstruction* convolution) = 0; - virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0; - virtual Status HandleCompare(HloInstruction* compare) { - return HandleElementwiseBinary(compare); + virtual Status HandleDot(HloInstructionPtr hlo) = 0; + virtual Status HandlePower(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleAdd(HloInstruction* add) { - return HandleElementwiseBinary(add); + virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; + virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleCompare(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleDivide(HloInstruction* divide) { - return HandleElementwiseBinary(divide); + virtual Status HandleAdd(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleRemainder(HloInstruction* remainder) { - return HandleElementwiseBinary(remainder); + virtual Status HandleDivide(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleSubtract(HloInstruction* subtract) { - return HandleElementwiseBinary(subtract); + virtual Status HandleRemainder(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleAbs(HloInstruction* abs) { - return HandleElementwiseUnary(abs); + virtual Status HandleSubtract(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleAtan2(HloInstruction* atan2) { - return HandleElementwiseBinary(atan2); + virtual Status HandleAbs(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleRound(HloInstruction* round) { - return HandleElementwiseUnary(round); + virtual Status HandleAtan2(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleSign(HloInstruction* sign) { - return HandleElementwiseUnary(sign); + virtual Status HandleRound(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleNegate(HloInstruction* negate) { - return HandleElementwiseUnary(negate); + virtual Status HandleSign(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleExp(HloInstruction* exp) { - return HandleElementwiseUnary(exp); + virtual Status HandleNegate(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleFloor(HloInstruction* floor) { - return HandleElementwiseUnary(floor); + virtual Status HandleExp(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleCeil(HloInstruction* ceil) { - return HandleElementwiseUnary(ceil); + virtual Status HandleFloor(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleLog(HloInstruction* log) { - return HandleElementwiseUnary(log); + virtual Status HandleCeil(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleCos(HloInstruction* cos) { - return HandleElementwiseUnary(cos); + virtual Status HandleLog(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleSin(HloInstruction* sin) { - return HandleElementwiseUnary(sin); + virtual Status HandleCos(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleTanh(HloInstruction* tanh) { - return HandleElementwiseUnary(tanh); + virtual Status HandleSin(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleReal(HloInstruction* real) { - return HandleElementwiseUnary(real); + virtual Status HandleTanh(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleImag(HloInstruction* imag) { - return HandleElementwiseUnary(imag); + virtual Status HandleReal(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleIsFinite(HloInstruction* is_finite) { - return HandleElementwiseUnary(is_finite); + virtual Status HandleImag(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleAnd(HloInstruction* and_) { - return HandleElementwiseBinary(and_); + virtual Status HandleIsFinite(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleNot(HloInstruction* not_) { - return HandleElementwiseUnary(not_); + virtual Status HandleAnd(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleOr(HloInstruction* or_) { - return HandleElementwiseBinary(or_); + virtual Status HandleNot(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleShiftLeft(HloInstruction* shift_left) { - return HandleElementwiseBinary(shift_left); + virtual Status HandleOr(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleShiftRightArithmetic( - HloInstruction* shift_right_arithmetic) { - return HandleElementwiseBinary(shift_right_arithmetic); + virtual Status HandleShiftLeft(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleShiftRightLogical(HloInstruction* shift_right_logical) { - return HandleElementwiseBinary(shift_right_logical); + virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { - return HandleElementwiseUnary(reduce_precision); + virtual Status HandleReducePrecision(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleInfeed(HloInstruction* infeed) = 0; - virtual Status HandleOutfeed(HloInstruction* outfeed) = 0; - virtual Status HandleRng(HloInstruction* random) = 0; - virtual Status HandleReverse(HloInstruction* reverse) = 0; - virtual Status HandleSort(HloInstruction* sort) = 0; - virtual Status HandleConstant(HloInstruction* constant) = 0; - virtual Status HandleGetTupleElement(HloInstruction* get_tuple_element) = 0; - virtual Status HandleReduce(HloInstruction* reduce) = 0; - virtual Status HandleBitcast(HloInstruction* bitcast) = 0; - virtual Status HandleBroadcast(HloInstruction* broadcast) = 0; - virtual Status HandleReshape(HloInstruction* reshape) = 0; - virtual Status HandleTranspose(HloInstruction* transpose) = 0; - virtual Status HandleParameter(HloInstruction* parameter) = 0; - virtual Status HandleFusion(HloInstruction* fusion) = 0; - virtual Status HandleCall(HloInstruction* call) = 0; - virtual Status HandleCustomCall(HloInstruction* custom_call) = 0; - virtual Status HandleSlice(HloInstruction* slice) = 0; - virtual Status HandleDynamicSlice(HloInstruction* dynamic_slice) = 0; - virtual Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) = 0; - virtual Status HandleTuple(HloInstruction* tuple) = 0; - virtual Status HandleMap(HloInstruction* map) = 0; - virtual Status HandleReduceWindow(HloInstruction* reduce_window) = 0; - virtual Status HandleSelectAndScatter(HloInstruction* instruction) = 0; - virtual Status HandleWhile(HloInstruction* xla_while) = 0; + virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; + virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; + virtual Status HandleRng(HloInstructionPtr hlo) = 0; + virtual Status HandleReverse(HloInstructionPtr hlo) = 0; + virtual Status HandleSort(HloInstructionPtr hlo) = 0; + virtual Status HandleConstant(HloInstructionPtr hlo) = 0; + virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; + virtual Status HandleReduce(HloInstructionPtr hlo) = 0; + virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; + virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; + virtual Status HandleReshape(HloInstructionPtr hlo) = 0; + virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; + virtual Status HandleParameter(HloInstructionPtr hlo) = 0; + virtual Status HandleFusion(HloInstructionPtr hlo) = 0; + virtual Status HandleCall(HloInstructionPtr hlo) = 0; + virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0; + virtual Status HandleSlice(HloInstructionPtr hlo) = 0; + virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0; + virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0; + virtual Status HandleTuple(HloInstructionPtr hlo) = 0; + virtual Status HandleMap(HloInstructionPtr hlo) = 0; + virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; + virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; + virtual Status HandleWhile(HloInstructionPtr hlo) = 0; + virtual Status HandleConditional(HloInstructionPtr hlo) = 0; - virtual Status HandlePad(HloInstruction* pad) = 0; + virtual Status HandlePad(HloInstructionPtr hlo) = 0; - virtual Status HandleSend(HloInstruction* send) = 0; + virtual Status HandleSend(HloInstructionPtr send) = 0; + virtual Status HandleSendDone(HloInstructionPtr send_done) = 0; - virtual Status HandleRecv(HloInstruction* recv) = 0; + virtual Status HandleRecv(HloInstructionPtr recv) = 0; + virtual Status HandleRecvDone(HloInstructionPtr recv_done) = 0; - virtual Status HandleBatchNormTraining( - HloInstruction* batch_norm_training) = 0; + virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; - virtual Status HandleBatchNormInference( - HloInstruction* batch_norm_inference) = 0; + virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0; - virtual Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) = 0; + virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". - virtual Status FinishVisit(HloInstruction* root) = 0; + virtual Status FinishVisit(HloInstructionPtr root) = 0; // 3 possible visitation states of HLO instructions. Each instruction's // state only flows one way: kNotVisited -> kVisiting -> kVisited. @@ -273,7 +287,7 @@ class DfsHloVisitor { // // Overriding methods should call DfsHloVisitor::Preprocess before doing their // own preprocessing. - virtual Status Preprocess(HloInstruction* hlo); + virtual Status Preprocess(HloInstructionPtr hlo); // This method should be overridden by subclasses that wish to run some // operation on an op after its Handle* visitor method is called. See @@ -281,7 +295,7 @@ class DfsHloVisitor { // // Overriding methods should call DfsHloVisitor::Postprocess after doing their // own postprocessing. - virtual Status Postprocess(HloInstruction* visited); + virtual Status Postprocess(HloInstructionPtr hlo); private: class DFSVisitStates { @@ -322,9 +336,14 @@ class DfsHloVisitor { DFSVisitStates visit_state_; - TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitor); + TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase); }; +// Users should use one of these two type aliases, which are the only two valid +// instantiations of DfsHloVisitorBase. +using DfsHloVisitor = DfsHloVisitorBase; +using ConstDfsHloVisitor = DfsHloVisitorBase; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index a1d7acf90429e3611bb6dea56d98bbd6ffb8f580..133aa2509405738de8388708b0c61a82023e2738 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -33,161 +33,198 @@ class HloComputation; class HloInstruction; // DfsHloVisitor with default action based on the HloInstruction being visited. -class DfsHloVisitorWithDefault : public DfsHloVisitor { +// Users should not use this class directly, but use the type aliases +// DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead. +template +class DfsHloVisitorWithDefaultBase + : public DfsHloVisitorBase { public: - DfsHloVisitorWithDefault() {} - ~DfsHloVisitorWithDefault() override {} + DfsHloVisitorWithDefaultBase() {} + ~DfsHloVisitorWithDefaultBase() override {} // Default action performed on HloInstruction. - virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0; + virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0; - Status HandleElementwiseUnary(HloInstruction* hlo) override { + Status HandleElementwiseUnary(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleElementwiseBinary(HloInstruction* hlo) override { + Status HandleElementwiseBinary(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleBatchNormTraining(HloInstruction* hlo) override { + Status HandleBatchNormTraining(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleBatchNormInference(HloInstruction* hlo) override { + Status HandleBatchNormInference(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleBatchNormGrad(HloInstruction* hlo) override { + Status HandleBatchNormGrad(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleClamp(HloInstruction* clamp) override { + Status HandleClamp(HloInstructionPtr clamp) override { return DefaultAction(clamp); } - Status HandleConcatenate(HloInstruction* concatenate) override { + Status HandleConcatenate(HloInstructionPtr concatenate) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstruction* convert) override { + Status HandleConvert(HloInstructionPtr convert) override { return DefaultAction(convert); } - Status HandleCopy(HloInstruction* copy) override { + Status HandleCopy(HloInstructionPtr copy) override { return DefaultAction(copy); } - Status HandleSelect(HloInstruction* select) override { + Status HandleSelect(HloInstructionPtr select) override { return DefaultAction(select); } - Status HandleDot(HloInstruction* dot) override { return DefaultAction(dot); } - Status HandleConvolution(HloInstruction* convolution) override { + Status HandleDot(HloInstructionPtr dot) override { + return DefaultAction(dot); + } + Status HandleConvolution(HloInstructionPtr convolution) override { return DefaultAction(convolution); } - Status HandleCrossReplicaSum(HloInstruction* crs) override { + Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleCompare(HloInstruction* compare) override { + Status HandleCompare(HloInstructionPtr compare) override { return DefaultAction(compare); } - Status HandleRng(HloInstruction* random) override { + Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } - Status HandleInfeed(HloInstruction* infeed) override { + Status HandleInfeed(HloInstructionPtr infeed) override { return DefaultAction(infeed); } - Status HandleOutfeed(HloInstruction* outfeed) override { + Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } - Status HandleReverse(HloInstruction* reverse) override { + Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } - Status HandleSort(HloInstruction* sort) override { + Status HandleSort(HloInstructionPtr sort) override { return DefaultAction(sort); } - Status HandleConstant(HloInstruction* constant) override { + Status HandleConstant(HloInstructionPtr constant) override { return DefaultAction(constant); } - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override { + Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override { return DefaultAction(get_tuple_element); } - Status HandleParameter(HloInstruction* parameter) override { + Status HandleParameter(HloInstructionPtr parameter) override { return DefaultAction(parameter); } - Status HandleFusion(HloInstruction* fusion) override { + Status HandleFusion(HloInstructionPtr fusion) override { return DefaultAction(fusion); } - Status HandleCall(HloInstruction* call) override { + Status HandleCall(HloInstructionPtr call) override { return DefaultAction(call); } - Status HandleCustomCall(HloInstruction* custom_call) override { + Status HandleCustomCall(HloInstructionPtr custom_call) override { return DefaultAction(custom_call); } - Status HandleSlice(HloInstruction* slice) override { + Status HandleSlice(HloInstructionPtr slice) override { return DefaultAction(slice); } - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { + Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override { return DefaultAction(dynamic_slice); } Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { + HloInstructionPtr dynamic_update_slice) override { return DefaultAction(dynamic_update_slice); } - Status HandleTuple(HloInstruction* tuple) override { + Status HandleTuple(HloInstructionPtr tuple) override { return DefaultAction(tuple); } - Status HandleMap(HloInstruction* map) override { return DefaultAction(map); } - Status HandleReduce(HloInstruction* reduce) override { + Status HandleMap(HloInstructionPtr map) override { + return DefaultAction(map); + } + Status HandleReduce(HloInstructionPtr reduce) override { return DefaultAction(reduce); } - Status HandleReduceWindow(HloInstruction* reduce_window) override { + Status HandleReduceWindow(HloInstructionPtr reduce_window) override { return DefaultAction(reduce_window); } - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override { return DefaultAction(select_and_scatter); } - Status HandleBitcast(HloInstruction* bitcast) override { + Status HandleBitcast(HloInstructionPtr bitcast) override { return DefaultAction(bitcast); } - Status HandleBroadcast(HloInstruction* broadcast) override { + Status HandleBroadcast(HloInstructionPtr broadcast) override { return DefaultAction(broadcast); } - Status HandlePad(HloInstruction* pad) override { return DefaultAction(pad); } - Status HandleReshape(HloInstruction* reshape) override { + Status HandlePad(HloInstructionPtr pad) override { + return DefaultAction(pad); + } + Status HandleReshape(HloInstructionPtr reshape) override { return DefaultAction(reshape); } - Status HandleTranspose(HloInstruction* transpose) override { + Status HandleTranspose(HloInstructionPtr transpose) override { return DefaultAction(transpose); } - Status HandleWhile(HloInstruction* xla_while) override { + Status HandleWhile(HloInstructionPtr xla_while) override { return DefaultAction(xla_while); } - Status HandleSend(HloInstruction* send) override { - return DefaultAction(send); + Status HandleConditional(HloInstructionPtr conditional) override { + return DefaultAction(conditional); } - Status HandleRecv(HloInstruction* recv) override { + Status HandleRecv(HloInstructionPtr recv) override { return DefaultAction(recv); } + Status HandleRecvDone(HloInstructionPtr recv_done) override { + return DefaultAction(recv_done); + } + Status HandleSend(HloInstructionPtr send) override { + return DefaultAction(send); + } + Status HandleSendDone(HloInstructionPtr send_done) override { + return DefaultAction(send_done); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". - Status FinishVisit(HloInstruction* /*root*/) override { return Status::OK(); } + Status FinishVisit(HloInstructionPtr /*root*/) override { + return Status::OK(); + } private: - TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefault); + TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase); }; -// Helper class for Accept(VisitorFunction) which visits instructions in DFS -// order calling the given function at each instruction. -class FunctionVisitor : public DfsHloVisitorWithDefault { +// Users should use these type aliases which are only two valid instantiations. +using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase; +using ConstDfsHloVisitorWithDefault = + DfsHloVisitorWithDefaultBase; + +// (Const)FunctionVisitor lets you transform an +// std::function into a (Const)DfsHloVisitor. +// +// This is useful if you have code that needs to handle visitors in the form of +// both std::function and DfsHloVisitor. You can wrap the function in a +// FunctionVisitor and then treat it like any other DfsHloVisitor. +template +class FunctionVisitorBase + : public DfsHloVisitorWithDefaultBase { public: - using VisitorFunction = std::function; - explicit FunctionVisitor(VisitorFunction visitor_func) + explicit FunctionVisitorBase( + std::function visitor_func) : visitor_func_(std::move(visitor_func)) {} - Status DefaultAction(HloInstruction* hlo_instruction) override { + Status DefaultAction(HloInstructionPtr hlo_instruction) override { return visitor_func_(hlo_instruction); } private: - VisitorFunction visitor_func_; + TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase); + + std::function visitor_func_; }; +using FunctionVisitor = FunctionVisitorBase; +using ConstFunctionVisitor = FunctionVisitorBase; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc new file mode 100644 index 0000000000000000000000000000000000000000..12faed69677cd99c6ed82c8d13dad3138d9461b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -0,0 +1,185 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dot_decomposer.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// TODO(b/69062148) Remove this code when all backends support BatchDot +// natively. +Status DecomposeBatchDot(HloInstruction* dot) { + auto computation = dot->parent(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const Shape& dot_shape = dot->shape(); + + // ShapeInference should guarantee that lhs/rhs batch dimensions match. + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + const int64 num_batch_dims = dnums.lhs_batch_dimensions_size(); + // Calculate total batch size (note that ShapeInference requires that + // the batch dimensions are most-major). + int64 batch_size = 1; + for (int i = 0; i < num_batch_dims; ++i) { + CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)), + rhs_shape.dimensions(dnums.rhs_batch_dimensions(i))); + batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)); + } + + // Set lhs/rhs_transpose. + CHECK_EQ(1, dnums.lhs_contracting_dimensions_size()); + const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0); + const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0; + + CHECK_EQ(1, dnums.rhs_contracting_dimensions_size()); + const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0); + const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1; + + // Compute R3 and R3 shapes for lhs. + PrimitiveType lhs_type = lhs_shape.element_type(); + const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0); + const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1); + Shape lhs_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r2 = + ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols}); + + // Compute R3 and R3 shapes for rhs. + PrimitiveType rhs_type = rhs_shape.element_type(); + const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0); + const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1); + Shape rhs_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r2 = + ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols}); + + // Compute R3 and R3 shapes for dot output. + PrimitiveType dot_type = dot_shape.element_type(); + const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0); + const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1); + Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols}); + Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols}); + Shape concat_shape_r3 = + ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols}); + + // Reshape lhs/rhs into R3. + auto lhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_shape_r3, lhs)); + auto rhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_shape_r3, rhs)); + + // Loop through batch size, slicing out required lhs/rhs to compute each Dot. + std::vector output_slices(batch_size); + for (int64 i = 0; i < batch_size; ++i) { + // Slice R3 shape from 'lhs' and reshape to R2. + auto lhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0}, + {i + 1, lhs_rows, lhs_cols}, {1, 1, 1})); + auto lhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3)); + + // Slice R3 shape from 'rhs' and reshape to R2. + auto rhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0}, + {i + 1, rhs_rows, rhs_cols}, {1, 1, 1})); + auto rhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3)); + + // Transpose lhs/rhs (if needed). + if (lhs_transpose) { + Shape lhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows}); + lhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0})); + } + if (rhs_transpose) { + Shape rhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows}); + rhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0})); + } + + // Compute Dot of lhs/rhs R2 slices. + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( + dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + + // Reshape Dot to R3 so we can concat along batch dimension. + auto dot_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape_r3, dot_r2)); + + output_slices[i] = dot_r3; + } + + // Concatenate slices from 'output_slices' along batch dimension. + auto concat = computation->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0)); + // Reshape output 'new_dot' to original dimensions. + auto new_dot = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape, concat)); + + // Replace all uses of 'dot' in 'computation' with 'new_dot'. + return computation->ReplaceInstruction(dot, new_dot); +} + +} // namespace + +StatusOr DotDecomposer::Run(HloModule* module) { + XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); + // Gather all batch Dot operations. + std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + bool changed = false; + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h new file mode 100644 index 0000000000000000000000000000000000000000..5ff0ab34eac0cd0fbc264b408c57653c944402a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// DotDecomposer is a pass which decomposes batch Dot operations into a +// sequence of smaller (R2) Dot operations. +class DotDecomposer : public HloPassInterface { + public: + // Decomposes batch Dot operations when 'decompose_batch_dot' is true. + DotDecomposer(bool decompose_batch_dot = true) + : decompose_batch_dot_(decompose_batch_dot) {} + ~DotDecomposer() = default; + tensorflow::StringPiece name() const override { return "dot_decomposer"; } + + // Run DotDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. + StatusOr Run(HloModule* module) override; + + private: + bool decompose_batch_dot_; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index fd4c332cba94513ec5b4cd88a842189e716f35d5..7e88bbd63123cd33682bb5ff67761ae5c5bdc98c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -50,11 +50,161 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrCat; +namespace { + +llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, + int64 mantissa_bits, + llvm::IRBuilder<>* ir_builder) { + // Integer and float types for casting and constant generation. + llvm::Type* float_type = x->getType(); + llvm::IntegerType* int_type = ir_builder->getInt32Ty(); + + // Cast the input value to an integer for bitwise manipulation. + llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type); + + if (mantissa_bits < 23) { + // Last remaining mantissa bit. + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. + const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; + llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr( + ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + (23 - mantissa_bits)); + llvm::Value* x_rounding_bias = ir_builder->CreateAdd( + x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); + + // Add rounding bias, and mask out truncated bits. Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + } + + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + llvm::Value* x_exponent = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + llvm::Value* x_overflows = ir_builder->CreateICmpUGT( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); + llvm::Value* x_underflows = ir_builder->CreateICmpULE( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); + + // Compute appropriately-signed values of zero and infinity. + llvm::Value* x_signed_zero = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); + llvm::Value* x_signed_inf = ir_builder->CreateOr( + x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int); + } + + // Cast the result back to a floating-point type. + llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type); + + // Correct result for NaN inputs. + // + // The exponent handling will "normalize" NaN values to infinities, which is + // undesirable (except in the case with no mantissa bits, in which case it + // is mandatory). This logic also handles cases where mantissa-rounding + // causes a NaN's mantissa to overflow into the exponent bits, which would + // otherwise create an erroneous zero value. + // + // If the fast-math flags are set to assume no NaNs, the comparison is likely + // to be optimized away, so there's no point in even emitting it. + if (!ir_builder->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x); + + if (mantissa_bits > 0) { + result = ir_builder->CreateSelect(x_is_nan, x, result); + } else { + result = ir_builder->CreateSelect( + x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); + } + } + return result; +} + +llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, + llvm::IRBuilder<>* ir_builder) { + auto reduced_precision = EmitReducePrecisionFloat( + f32_value, + /*exponent_bits=*/primitive_util::kBFloat16ExponentBits, + /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder); + auto as_int32 = + ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateLShr(as_int32, 16); + auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty()); + return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty()); +} + +llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, + llvm::IRBuilder<>* ir_builder) { + auto as_int16 = + ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty()); + auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateShl(as_int32, 16); + return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy()); +} + +llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, + PrimitiveType from_type, + PrimitiveType to_type, llvm::Module* module, + llvm::IRBuilder<>* ir_builder) { + if (primitive_util::IsSignedIntegralType(from_type)) { + return ir_builder->CreateSIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } else { + CHECK(primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED); + return ir_builder->CreateUIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } +} + +} // namespace + StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { if (op->opcode() == HloOpcode::kCopy) { return operand_value; - } else if (operand_value->getType()->isIntegerTy()) { + } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + op->operand(0)->shape().element_type() == PRED) { return EmitIntegerUnaryOp(op, operand_value); } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { return EmitComplexUnaryOp(op, operand_value); @@ -79,28 +229,27 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::IsSignedIntegralType(to_type)); } if (primitive_util::IsFloatingPointType(to_type)) { - if (primitive_util::IsSignedIntegralType(from_type)) { - return ir_builder_->CreateSIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); - } - if (primitive_util::IsUnsignedIntegralType(from_type) || - from_type == PRED) { - return ir_builder_->CreateUIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + if (to_type == BF16) { + return EmitF32ToBF16( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + ir_builder_), + ir_builder_); } + return EmitIntegralToFloating(operand_value, from_type, to_type, + module_, ir_builder_); } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( primitive_util::ComplexComponentType(to_type), module_); if (primitive_util::IsSignedIntegralType(from_type)) { - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateSIToFP(operand_value, to_ir_component_type), nullptr); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateUIToFP(operand_value, to_ir_component_type), nullptr); @@ -110,6 +259,26 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( PrimitiveType_Name(from_type).c_str(), PrimitiveType_Name(to_type).c_str()); } + case HloOpcode::kBitcastConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + PrimitiveType to_type = op->shape().element_type(); + CHECK(primitive_util::IsIntegralType(from_type)); + if (from_type == to_type) { + return operand_value; + } + if (primitive_util::BitWidth(from_type) == + primitive_util::BitWidth(to_type)) { + return ir_builder_->CreateBitCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + } + return InvalidArgument( + "bitcast conversion from primitive type %s to %s with unequal " + "bit-widths (%u versus %u) ", + PrimitiveType_Name(from_type).c_str(), + PrimitiveType_Name(to_type).c_str(), + primitive_util::BitWidth(from_type), + primitive_util::BitWidth(to_type)); + } case HloOpcode::kAbs: { bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); @@ -178,15 +347,26 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType to_component_type = primitive_util::ComplexComponentType(to_type); if (from_type == to_component_type) { - return ComposeComplex(op, operand_value, nullptr); + return EmitComposeComplex(op, operand_value, nullptr); } - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } + if (from_type == BF16) { + TF_RET_CHECK(to_type != BF16); + operand_value = EmitBF16ToF32(operand_value, ir_builder_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F32 && to_type == BF16) { + return EmitF32ToBF16(operand_value, ir_builder_); + } if (primitive_util::IsFloatingPointType(to_type)) { return ir_builder_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); @@ -203,6 +383,26 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType_Name(from_type).c_str(), PrimitiveType_Name(to_type).c_str()); } + case HloOpcode::kBitcastConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + PrimitiveType to_type = op->shape().element_type(); + CHECK(primitive_util::IsFloatingPointType(from_type)); + if (from_type == to_type) { + return operand_value; + } + if (primitive_util::BitWidth(from_type) == + primitive_util::BitWidth(to_type)) { + return ir_builder_->CreateBitCast( + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + } + return InvalidArgument( + "bitcast conversion from primitive type %s to %s with unequal " + "bit-widths (%u versus %u) ", + PrimitiveType_Name(from_type).c_str(), + PrimitiveType_Name(to_type).c_str(), + primitive_util::BitWidth(from_type), + primitive_util::BitWidth(to_type)); + } case HloOpcode::kExp: return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value}, {operand_value->getType()}, @@ -269,15 +469,8 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( StatusOr ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { // TODO(b/65209142): Angle/Log require atan2. - // case HloOpcode::kAngle: // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); @@ -291,24 +484,26 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( primitive_util::ComplexComponentType(to_type); auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); - return ComposeComplex( + return EmitComposeComplex( op, - ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type), - ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type)); + ir_builder_->CreateFPCast(EmitExtractReal(operand_value), + to_ir_component_type), + ir_builder_->CreateFPCast(EmitExtractImag(operand_value), + to_ir_component_type)); } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) auto exp_a = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::exp, {real(operand_value)}, - {real(operand_value)->getType()}, ir_builder_); + llvm::Intrinsic::exp, {EmitExtractReal(operand_value)}, + {EmitExtractReal(operand_value)->getType()}, ir_builder_); auto cos_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::cos, {imag(operand_value)}, - {imag(operand_value)->getType()}, ir_builder_); + llvm::Intrinsic::cos, {EmitExtractImag(operand_value)}, + {EmitExtractImag(operand_value)->getType()}, ir_builder_); auto sin_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sin, {imag(operand_value)}, - {imag(operand_value)->getType()}, ir_builder_); - return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + llvm::Intrinsic::sin, {EmitExtractImag(operand_value)}, + {EmitExtractImag(operand_value)->getType()}, ir_builder_); + return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kCos: { // cos(z) = .5(e^(iz) + e^(-iz)) @@ -318,8 +513,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( // cos(-x) = cos(x) and sin(-x) = -sin(x), so // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i)) // = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); auto type = a->getType(); auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, {type}, ir_builder_); @@ -331,7 +526,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( {type}, ir_builder_); auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, {type}, ir_builder_); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), @@ -348,8 +543,8 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( // cos(-x) = cos(x) and sin(-x) = -sin(x), so // = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a))) // = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); auto type = a->getType(); auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, {type}, ir_builder_); @@ -361,7 +556,7 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( {type}, ir_builder_); auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, {type}, ir_builder_); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), @@ -370,33 +565,40 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kAbs: { auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(operand_value), real(operand_value)), - ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + ir_builder_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + ir_builder_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); } case HloOpcode::kSign: { // Sign(c) = c / |c| auto sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(operand_value), real(operand_value)), - ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + ir_builder_->CreateFMul(EmitExtractReal(operand_value), + EmitExtractReal(operand_value)), + ir_builder_->CreateFMul(EmitExtractImag(operand_value), + EmitExtractImag(operand_value))); auto cplx_abs = llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); auto type = cplx_abs->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero); return ir_builder_->CreateSelect( - oeq, ComposeComplex(op, zero, zero), - ComposeComplex( - op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs), - ir_builder_->CreateFDiv(imag(operand_value), cplx_abs))); + oeq, EmitComposeComplex(op, zero, zero), + EmitComposeComplex( + op, + ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs), + ir_builder_->CreateFDiv(EmitExtractImag(operand_value), + cplx_abs))); } case HloOpcode::kNegate: - return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)), - ir_builder_->CreateFNeg(imag(operand_value))); + return EmitComposeComplex( + op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)), + ir_builder_->CreateFNeg(EmitExtractImag(operand_value))); case HloOpcode::kReal: - return real(operand_value); + return EmitExtractReal(operand_value); case HloOpcode::kImag: - return imag(operand_value); + return EmitExtractImag(operand_value); default: return Unimplemented("unary complex op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -407,7 +609,8 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { PrimitiveType operand_type = op->operand(0)->shape().element_type(); - if (lhs_value->getType()->isIntegerTy()) { + if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + operand_type == PRED) { return EmitIntegerBinaryOp( op, lhs_value, rhs_value, primitive_util::IsSignedIntegralType(operand_type)); @@ -424,7 +627,7 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( switch (op->opcode()) { // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support case HloOpcode::kComplex: - return ComposeComplex(op, lhs_value, rhs_value); + return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: return ir_builder_->CreateFAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: @@ -479,54 +682,66 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( StatusOr ElementalIrEmitter::EmitComplexBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { case HloOpcode::kAdd: - return ComposeComplex( - op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value))); + return EmitComposeComplex( + op, + ir_builder_->CreateFAdd(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFAdd(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kSubtract: - return ComposeComplex( - op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value))); + return EmitComposeComplex( + op, + ir_builder_->CreateFSub(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFSub(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); case HloOpcode::kMultiply: - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)))); case HloOpcode::kDivide: { // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) auto rhs_sum_sq = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(rhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(rhs_value), + EmitExtractImag(rhs_value))); auto type = rhs_sum_sq->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero); + auto real_inf_or_nan = + ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero); + auto imag_inf_or_nan = + ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero); return ir_builder_->CreateSelect( - oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero), - ComposeComplex( + oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan), + EmitComposeComplex( op, ir_builder_->CreateFDiv( ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), - imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq), ir_builder_->CreateFDiv( ir_builder_->CreateFSub( - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(real(lhs_value), - imag(rhs_value))), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value))), rhs_sum_sq))); } // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered @@ -538,16 +753,20 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( // matches C++'s semantics. case HloOpcode::kEq: return ir_builder_->CreateAnd( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value), - real(rhs_value), ir_builder_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value), - imag(rhs_value), ir_builder_)); + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), ir_builder_)); case HloOpcode::kNe: return ir_builder_->CreateOr( - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value), - real(rhs_value), ir_builder_), - llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value), - imag(rhs_value), ir_builder_)); + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, + EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value), ir_builder_)); // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic // case HloOpcode::kPower: @@ -659,111 +878,9 @@ StatusOr ElementalIrEmitter::EmitReducePrecision( if (hlo->operand(0)->shape().element_type() != F32) { return Unimplemented("reduce-precision only implemented for F32"); } - - // Integer and float types for casting and constant generation. - llvm::Type* float_type = x->getType(); - llvm::IntegerType* int_type = ir_builder_->getInt32Ty(); - - // Cast the input value to an integer for bitwise manipulation. - llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type); - - if (hlo->mantissa_bits() < 23) { - // Last remaining mantissa bit. - const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits()); - - // Compute rounding bias for round-to-nearest with ties to even. This is - // equal to a base value of 0111... plus one bit if the last remaining - // mantissa bit is 1. - const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; - llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr( - ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), - (23 - hlo->mantissa_bits())); - llvm::Value* x_rounding_bias = ir_builder_->CreateAdd( - x_last_mantissa_bit, - llvm::ConstantInt::get(int_type, base_rounding_bias)); - - // Add rounding bias, and mask out truncated bits. Note that the case - // where adding the rounding bias overflows into the exponent bits is - // correct; the non-masked mantissa bits will all be zero, and the - // exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias); - x_as_int = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); - } - - if (hlo->exponent_bits() < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- - // significant bit -- is equal to 1.0f for all exponent sizes. Adding - // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- - // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' - // exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n is - // (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (hlo->exponent_bits() - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; - - // Do we overflow or underflow? - llvm::Value* x_exponent = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - llvm::Value* x_overflows = ir_builder_->CreateICmpUGT( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); - llvm::Value* x_underflows = ir_builder_->CreateICmpULE( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); - - // Compute appropriately-signed values of zero and infinity. - llvm::Value* x_signed_zero = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); - llvm::Value* x_signed_inf = ir_builder_->CreateOr( - x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - - // Force to zero or infinity if overflow or underflow. (Note that this - // truncates all denormal values to zero, rather than rounding them.) - x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int); - x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int); - } - - // Cast the result back to a floating-point type. - llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type); - - // Correct result for NaN inputs. - // - // The exponent handling will "normalize" NaN values to infinities, which is - // undesirable (except in the case with no mantissa bits, in which case it - // is mandatory). This logic also handles cases where mantissa-rounding - // causes a NaN's mantissa to overflow into the exponent bits, which would - // otherwise create an erroneous zero value. - // - // If the fast-math flags are set to assume no NaNs, the comparison is likely - // to be optimized away, so there's no point in even emitting it. - if (!ir_builder_->getFastMathFlags().noNaNs()) { - llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x); - - if (hlo->mantissa_bits() > 0) { - result = ir_builder_->CreateSelect(x_is_nan, x, result); - } else { - result = ir_builder_->CreateSelect( - x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); - } - } - return result; + return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), + /*mantissa_bits=*/hlo->mantissa_bits(), + ir_builder_); } StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( @@ -847,7 +964,7 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( // If no implicit broadcast is needed for this operand, returns the target // index as the source index. - if (ShapeUtil::Compatible(operand_shape, hlo.shape())) { + if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) { return target_index; } @@ -1055,6 +1172,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: @@ -1063,11 +1181,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kNegate: + case HloOpcode::kNot: case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: - case HloOpcode::kNot: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, @@ -1076,6 +1194,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitUnaryOp(hlo, operand_value); }; case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kAtan2: case HloOpcode::kComplex: case HloOpcode::kDivide: @@ -1088,14 +1207,13 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: + case HloOpcode::kOr: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + case HloOpcode::kSubtract: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { const HloInstruction* lhs = hlo->operand(0); @@ -1289,6 +1407,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const int64 rank = ShapeUtil::Rank(input_hlo->shape()); llvm_ir::IrArray::Index slice_start_index(rank); llvm_ir::IrArray::Index slice_limit_index(rank); + // Slice starts at update[index - slice_start_index_adjusted], + // where adjusted value = slice_start_index when in bounds, and + // adjusted value = slice_start_index - input_dim, when wrapping. + llvm_ir::IrArray::Index slice_start_index_adjusted(rank); + + // Slice intersection gathers (ANDs) conditions on all ranks for which + // 'input' is set to 'update' + llvm::Value* slice_intersection = ir_builder_->getTrue(); + for (int64 i = 0; i < rank; ++i) { // Emit IR to read dynamic start indices from 'start_hlo'. llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); @@ -1298,38 +1425,97 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( start_index_value, index[i]->getType()); - // Emit IR to compute: slice_limit_index = start_index + update_dim - // NOTE: Although 'start_indices' is dynamic and could be - // out-of-range, we do not compute 'slice_limit_index' mod input dim - // size here, because subsequent array index calculations will be - // computed mod input dim size for safety. + + llvm::Value* input_dim_size = llvm::ConstantInt::get( + index[i]->getType(), input_hlo->shape().dimensions(i)); llvm::Value* update_dim_size = llvm::ConstantInt::get( index[i]->getType(), update_hlo->shape().dimensions(i)); + + // Generate code to handle wrapping semantics: + // slice_start_index[i] = slice_start_index[i] % input_dim_size; + // slice_limit_index[i] = slice_start_index[i] + update_dim_size. + // slice_start_index[i] is updated in place and it will now be in + // range. slice_limit_index[i] may be out of range, and it's being + // URem-ed below if so. + slice_start_index[i] = + ir_builder_->CreateURem(slice_start_index[i], input_dim_size); slice_limit_index[i] = ir_builder_->CreateAdd(slice_start_index[i], update_dim_size); - } - - // Check if 'index' intersects start/end indices. - llvm::Value* slice_intersection = - llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1); - for (int64 i = 0; i < rank; ++i) { - // Check that index[i] >= slice_start_index[i]. - slice_intersection = ir_builder_->CreateAnd( + // Test if slice_limit_index[i] is in bounds + llvm::Value* in_bounds = + ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size); + llvm_ir::LlvmIfData if_in_bounds = + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); + + // Handle true BB (slice_limit_index[i] <= input_dim_size). + SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_); + // Check that index[i] >= slice_start_index[i] && + // index[i] < slice_limit_index[i] + llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd( slice_intersection, ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), - "slice_intersection"); - - // Check that index[i] < slice_limit_index[i]. - slice_intersection = ir_builder_->CreateAnd( - slice_intersection, + "slice_intersection_in"); + slice_intersection_in_bounds = ir_builder_->CreateAnd( + slice_intersection_in_bounds, ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]), - "slice_intersection"); + "slice_intersection_in"); + + // Handle false BB (slice_limit_index[i] > input_dim_size). + SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_); + // Check that index[i] >= slice_start_index[i] || + // index[i] < slice_limit_index[i]%input_dim_size. + llvm::Value* index_wraps = ir_builder_->CreateICmpSLT( + index[i], + ir_builder_->CreateURem(slice_limit_index[i], input_dim_size)); + llvm::Value* slice_intersection_or = ir_builder_->CreateOr( + ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]), + index_wraps, "slice_intersection_out"); + llvm::Value* slice_intersection_out_of_bounds = + ir_builder_->CreateAnd(slice_intersection, slice_intersection_or, + "slice_intersection_out"); + // Create value for slice_start_index_adjusted[i] when out of bounds. + // If within out-of-bounds if. + llvm_ir::LlvmIfData if_start_needs_adjustment = + llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_); + SetToFirstInsertPoint(if_start_needs_adjustment.true_block, + ir_builder_); + llvm::Value* slice_start_index_adjusted_oob = + ir_builder_->CreateSub(slice_start_index[i], input_dim_size); + SetToFirstInsertPoint(if_start_needs_adjustment.after_block, + ir_builder_); + llvm::PHINode* slice_start_index_adjusted_phi = + ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(), + 2); + slice_start_index_adjusted_phi->addIncoming( + slice_start_index_adjusted_oob, + if_start_needs_adjustment.true_block); + slice_start_index_adjusted_phi->addIncoming( + slice_start_index[i], if_start_needs_adjustment.false_block); + // End of if within if. + + // After checking in/out of bounds. + SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_); + llvm::PHINode* phi_slice_intersection = + ir_builder_->CreatePHI(slice_intersection->getType(), 2); + phi_slice_intersection->addIncoming(slice_intersection_in_bounds, + if_in_bounds.true_block); + phi_slice_intersection->addIncoming( + slice_intersection_out_of_bounds, + if_start_needs_adjustment.after_block); + slice_intersection = phi_slice_intersection; + + llvm::PHINode* phi_index = + ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2); + phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block); + phi_index->addIncoming(slice_start_index_adjusted_phi, + if_start_needs_adjustment.after_block); + slice_start_index_adjusted[i] = phi_index; } // Emit: // if (slice_intersection) -> return data from 'update'. - // else -> return data from 'index'. + // else -> return data from 'input'. llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_), @@ -1337,7 +1523,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( slice_intersection, "slice_intersection", ir_builder_); - // Handle true BB. + // Handle true BB (return data from 'update') SetToFirstInsertPoint(if_data.true_block, ir_builder_); // Compute update index for intersection case. llvm_ir::IrArray::Index update_index(rank); @@ -1346,14 +1532,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( index[i]->getType(), update_hlo->shape().dimensions(i)); // NOTE: Subtraction will be positive due to bounds checking above. update_index[i] = ir_builder_->CreateURem( - ir_builder_->CreateSub(index[i], slice_start_index[i]), + ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), update_dim_size); } TF_ASSIGN_OR_RETURN(llvm::Value * true_value, operand_to_generator.at(update_hlo)(update_index)); ir_builder_->CreateStore(true_value, ret_value_addr); - // Handle false BB. + // Handle false BB (return data from 'input') SetToFirstInsertPoint(if_data.false_block, ir_builder_); TF_ASSIGN_OR_RETURN(llvm::Value * false_value, operand_to_generator.at(input_hlo)(index)); @@ -1497,25 +1683,25 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; if (primitive_util::IsComplexType(primitive_type)) { - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; llvm::Value* product_real = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractReal(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractImag(rhs_value))); llvm::Value* product_imag = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), - ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))); + ir_builder_->CreateFMul(EmitExtractReal(lhs_value), + EmitExtractImag(rhs_value)), + ir_builder_->CreateFMul(EmitExtractImag(lhs_value), + EmitExtractReal(rhs_value))); next_accumulator = ir_builder_->CreateInsertValue( current_accumulator, - ir_builder_->CreateFAdd(real(current_accumulator), product_real), + ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator), + product_real), {0}); next_accumulator = ir_builder_->CreateInsertValue( next_accumulator, - ir_builder_->CreateFAdd(imag(current_accumulator), product_imag), + ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator), + product_imag), {1}); } else if (primitive_util::IsFloatingPointType(primitive_type)) { next_accumulator = ir_builder_->CreateFAdd( @@ -1539,9 +1725,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } } -llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op, - llvm::Value* real, - llvm::Value* imag) const { +llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const { + return ir_builder_->CreateExtractValue(value, {0}); +} + +llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const { + return ir_builder_->CreateExtractValue(value, {1}); +} + +llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op, + llvm::Value* real, + llvm::Value* imag) const { auto cplx_type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto complex = ir_builder_->CreateInsertValue( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 9d32436e38fa2fb3e27d09f01b860cd2edf2c8ac..cccb498f82936283a215370787907b293827ff2d 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -95,6 +95,13 @@ class ElementalIrEmitter { virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; + virtual llvm::Value* EmitExtractReal(llvm::Value* value) const; + virtual llvm::Value* EmitExtractImag(llvm::Value* value) const; + + // Composes a complex struct. imag may be nullptr for simple cast operations. + llvm::Value* EmitComposeComplex(const HloInstruction* op, llvm::Value* real, + llvm::Value* imag) const; + // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its // `operand_no`-th operand. @@ -117,11 +124,6 @@ class ElementalIrEmitter { // compiled executable outside of the HLO code itself. const HloModuleConfig& hlo_module_config_; - protected: - // Composes a complex struct. imag may be nullptr for simple cast operations. - llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real, - llvm::Value* imag) const; - private: // Returns a ElementGenerator for a RNG HloInstruction. llvm_ir::ElementGenerator MakeRngElementGenerator( diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 2d32e59d36c4e3026e0e151561db3076146fabe4..08862308c90af736c1adcaa9438973f858852506 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -44,8 +44,15 @@ namespace xla { // interface that is used for launching compiled programs across platforms. class Executable { public: - explicit Executable(std::unique_ptr hlo_module) - : hlo_module_(std::move(hlo_module)) {} + explicit Executable(std::unique_ptr hlo_module, + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : hlo_module_(std::move(hlo_module)), + hlo_profile_printer_(std::move(hlo_profile_printer)), + hlo_profile_index_map_(std::move(hlo_profile_index_map)) { + CHECK_EQ(hlo_profile_printer_.get() == nullptr, + hlo_profile_index_map_.get() == nullptr); + } virtual ~Executable() {} // Enqueues the compilation result on the provided stream, passing the given @@ -88,6 +95,16 @@ class Executable { tensorflow::gtl::ArraySlice> arguments); + // Populates `hlo_execution_profile` from `executor`. This is implicit in any + // Execute* API call that takes a hlo_execution_profile argument, but must be + // called explicitly for other (async, for example) variants after the stream + // has completed. + virtual Status PopulateExecutionProfile( + HloExecutionProfile* hlo_execution_profile, + perftools::gputools::StreamExecutor* executor) { + return Status::OK(); + } + // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the // given ExecutionProfile if non-null. The ExecuteOnStream overloads have @@ -113,12 +130,20 @@ class Executable { "Equality test on this executable is not implemented."); } + const HloProfilePrinter& hlo_profile_printer() const { + CHECK(hlo_profiling_enabled()); + return *hlo_profile_printer_; + } + + const HloProfileIndexMap& hlo_profile_index_map() const { + CHECK(hlo_profiling_enabled()); + return *hlo_profile_index_map_; + } + // Returns whether this executable was compiled with HLO profilings support // enabled. If not, the caller should not expect an hlo_execution_profile // passed to ExecuteOnStream above to be populated during execution. - bool hlo_profiling_enabled() const { - return hlo_module_->config().hlo_profiling_enabled(); - } + bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } const HloModule& module() const { return *hlo_module_; } @@ -150,10 +175,6 @@ class Executable { static Status DumpToDirectory(const string& directory_path, string filename, const SessionModule& session_module); - // Returns a cost analysis object appropriate for the platform on which this - // executable can run. - virtual std::unique_ptr CreateCostAnalysis() const = 0; - protected: mutable tensorflow::mutex mutex_; @@ -171,6 +192,9 @@ class Executable { // Execution count, used to generate a unique filename for each dumped // execution. int64 execution_count_ = 0; + + std::unique_ptr hlo_profile_printer_; + std::unique_ptr hlo_profile_index_map_; }; template @@ -187,14 +211,15 @@ StatusOr Executable::ExecuteOnStreamWrapper( VLOG(1) << "enqueueing executable on stream..."; // If the profiling flag isn't enabled, we pass nullptr as the profile to // indicate profiling is not requested. - HloExecutionProfile hlo_execution_profile; - HloExecutionProfile* profile_ptr = + std::unique_ptr profile_ptr = module_config().debug_options().xla_hlo_profile() && hlo_profiling_enabled() - ? &hlo_execution_profile + ? MakeUnique(&hlo_profile_printer(), + &hlo_profile_index_map()) : nullptr; - auto return_value = ExecuteOnStream(run_options, arguments, profile_ptr); + auto return_value = + ExecuteOnStream(run_options, arguments, profile_ptr.get()); if (profile != nullptr) { VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; @@ -222,24 +247,11 @@ StatusOr Executable::ExecuteOnStreamWrapper( } if (profile_ptr != nullptr) { - std::unordered_set profiled_computations = - profile_ptr->profiled_computations(); - // To ensure we have print the profiles in a stable order, iterate over the - // computations in post order. - std::list all_computations = - module().MakeComputationPostOrder(); - for (xla::HloComputation* computation : all_computations) { - if (profiled_computations.count(computation) > 0) { - string profile_string = profile_ptr->ToString( - *computation, stream->parent()->GetDeviceDescription(), - CreateCostAnalysis().get()); - if (!profile_string.empty()) { - XLA_LOG_LINES(tensorflow::INFO, profile_string); - } - } - } + XLA_LOG_LINES( + tensorflow::INFO, + profile_ptr->ToString(stream->parent()->GetDeviceDescription())); hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", - profile_ptr); + profile_ptr.get()); } return return_value; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index dfba22a6c4c5cf071c2cd8621643b8da6587ee3b..2b6caa149439a86d6d047605099bc3ff7b295a8e 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -26,7 +26,10 @@ namespace xla { namespace { -// Helper to replace the called computation at a while- or call-instruction. +// Helper to replace the called computation at a while-, call-, or +// conditional-instruction. This function replaces exactly one instance of +// 'computation' with 'new_computation' even if 'instruction' calls +// 'computation' more than once. void ReplaceCalledComputation(HloInstruction* instruction, HloComputation* computation, HloComputation* new_computation) { @@ -45,6 +48,15 @@ void ReplaceCalledComputation(HloInstruction* instruction, instruction->set_to_apply(new_computation); break; } + case HloOpcode::kConditional: { + if (computation == instruction->true_computation()) { + instruction->set_true_computation(new_computation); + } else { + CHECK_EQ(computation, instruction->false_computation()); + instruction->set_false_computation(new_computation); + } + break; + } default: LOG(FATAL) << "unexpected opcode: " << HloOpcodeString(instruction->opcode()); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index a68e90b7d009890012f94baa790d911871c9c960..d3854b40de3572a60df1ad99d8a4589f59ad7194 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -223,5 +223,35 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { EXPECT_EQ(1, b_node.caller_callsites().size()); } +TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { + auto module = CreateNewModule(); + HloComputation* sub_computation = + module->AddEmbeddedComputation(MakeScalarComputation()); + + // Create entry computation, which is a conditional that has the same + // computation in the true and false branch. + HloComputation::Builder builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, constant1, sub_computation, constant2, + sub_computation)); + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(2, module->computation_count()); + + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + EXPECT_TRUE(result); + std::unique_ptr call_graph = CallGraph::Build(module.get()); + // The true and false computations must now be different. + EXPECT_EQ(3, module->computation_count()); + + const CallGraphNode& sub_node = call_graph->GetNode(sub_computation); + EXPECT_EQ(1, sub_node.caller_callsites().size()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index d3c83ea72e33b959e21d0cc9c1706d92bd659a5c..74aa77b4f165be76fbc0a8aa1a4a7e90a8e9acec 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/interpreter/platform_id.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -103,8 +104,7 @@ GenericTransferManager::ShallowCopyTupleFromDevice( // a vector of void* pointers. std::vector element_pointers(ShapeUtil::TupleElementCount(shape), nullptr); - int64 tuple_size = - ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + int64 tuple_size = ShapeUtil::ByteSizeOf(shape, pointer_size_); auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, element_pointers.data()); if (!copy_status.ok()) { @@ -121,9 +121,8 @@ GenericTransferManager::ShallowCopyTupleFromDevice( !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { return FailedPrecondition("tuple contains nullptr at element %lu", i); } - int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i), - /*pointer_size=*/sizeof(void*)); - destination.emplace_back(element_pointers[i], buffer_size); + destination.emplace_back(element_pointers[i], + GetByteSizeRequirement(shape.tuple_shapes(i))); } return std::move(destination); } @@ -138,11 +137,79 @@ Status GenericTransferManager::WriteTuplePointersToDevice( for (const se::DeviceMemoryBase& element : elements) { element_pointers.push_back(element.opaque()); } - int64 tuple_size = - ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + return TransferBufferToDevice(executor, GetByteSizeRequirement(shape), + element_pointers.data(), region); +} + +StatusOr> +GenericTransferManager::TransferLiteralFromDevice( + se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { + VLOG(2) << "transferring literal from device ordinal " + << executor->device_ordinal() << "; device shape: " + << ShapeUtil::HumanStringWithLayout(device_buffer.shape()) + << "; opaque: " << device_buffer.buffer(/*index=*/{}).opaque(); + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + std::unique_ptr literal = + Literal::CreateFromShape(device_buffer.shape()); + + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& subshape, const ShapeIndex& index) -> Status { + if (!ShapeUtil::IsTuple(subshape)) { + TF_RETURN_IF_ERROR(TransferBufferFromDevice( + executor, + /*source=*/device_buffer.buffer(index), + /*size=*/GetByteSizeRequirement(subshape), + /*destination=*/ + literal->GetSubliteral(index).MutableInternalData())); + } + + return Status::OK(); + })); + return std::move(literal); +} + +Status GenericTransferManager::TransferLiteralToDevice( + se::StreamExecutor* executor, const Literal& literal, + const ShapedBuffer& device_buffer) { + const Shape& shape = literal.shape(); + VLOG(2) << "transferring literal shape to device: " + << ShapeUtil::HumanString(shape) << "; device location: " + << device_buffer.buffer(/*index=*/{}).opaque(); + + TF_RET_CHECK(ShapeUtil::Compatible(literal.shape(), device_buffer.shape())); + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); - return TransferBufferToDevice(executor, tuple_size, element_pointers.data(), - region); + return ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { + se::DeviceMemoryBase device_memory = device_buffer.buffer(index); + if (ShapeUtil::IsArray(device_subshape)) { + TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == + device_memory.size()); + // Element is array-shaped: transfer array data to device buffer. + const Literal& subliteral = literal.GetSubliteral(index); + std::unique_ptr relayed_out_literal; + const void* source; + if (LayoutUtil::Equal(device_subshape.layout(), + subliteral.shape().layout())) { + source = subliteral.InternalData(); + } else { + // Relayout data before transferring. + relayed_out_literal = subliteral.Relayout(device_subshape.layout(), + /*shape_index=*/{}); + source = relayed_out_literal->InternalData(); + } + return TransferBufferToDevice( + executor, + /*size=*/GetByteSizeRequirement(device_subshape), source, + &device_memory); + } + return Status::OK(); + }); } Status GenericTransferManager::TransferLiteralToDevice( @@ -197,8 +264,8 @@ Status GenericTransferManager::ResetDevices( "Device reset is not yet supported on this platform (b/30481585)"); } -int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); +int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) const { + return ShapeUtil::ByteSizeOf(shape, pointer_size_); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 26488d6ec651b75c753119a7ce818c692c6c03dd..50dca6aec5012f0b02cb54846b622f008600e48e 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -52,6 +52,14 @@ class GenericTransferManager : public TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal, perftools::gputools::DeviceMemoryBase* destination) override; + StatusOr> TransferLiteralFromDevice( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) override; + + Status TransferLiteralToDevice(perftools::gputools::StreamExecutor* executor, + const Literal& literal, + const ShapedBuffer& device_buffer) override; + Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, @@ -71,6 +79,9 @@ class GenericTransferManager : public TransferManager { const perftools::gputools::DeviceMemoryBase& source, const Shape& shape) override; + int64 GetByteSizeRequirement(const Shape& shape) const override; + + protected: Status WriteTuplePointersToDevice( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice @@ -78,8 +89,6 @@ class GenericTransferManager : public TransferManager { const Shape& shape, perftools::gputools::DeviceMemoryBase* region) override; - int64 GetByteSizeRequirement(const Shape& shape) override; - private: // The platform this transfer manager targets. const perftools::gputools::Platform::Id platform_id_; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index de84e06cebab72d272bd888f280f5e5b221b97d1..4a72f87efdd92497ac4c2cd73b56c4990ed5b04c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -343,15 +343,16 @@ tf_cc_test( ) cc_library( - name = "copy_insertion", - srcs = ["copy_insertion.cc"], - hdrs = ["copy_insertion.h"], + name = "gpu_copy_insertion", + srcs = ["gpu_copy_insertion.cc"], + hdrs = ["gpu_copy_insertion.h"], deps = [ ":ir_emission_utils", + "//tensorflow/compiler/xla/service:call_graph", "//tensorflow/compiler/xla/service:copy_insertion", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:logical_buffer", - "//tensorflow/compiler/xla/service:tuple_points_to_analysis", + "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", + "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", ], ) @@ -427,14 +428,14 @@ cc_library( hdrs = ["gpu_compiler.h"], deps = [ ":convolution_folding", - ":copy_insertion", ":fusion_merger", + ":gpu_copy_insertion", ":gpu_executable", + ":gpu_layout_assignment", ":hlo_schedule", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -448,6 +449,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", @@ -465,10 +467,12 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", "@llvm//:support", @@ -489,9 +493,9 @@ cc_library( ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "gpu_layout_assignment", + srcs = ["gpu_layout_assignment.cc"], + hdrs = ["gpu_layout_assignment.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", @@ -505,10 +509,10 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", - srcs = ["layout_assignment_test.cc"], + name = "gpu_layout_assignment_test", + srcs = ["gpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":gpu_layout_assignment", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", @@ -572,11 +576,14 @@ tf_cc_test( deps = [ ":instruction_fusion", ":while_transformer", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:copy_insertion", + "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index 5aaf072f9d2c95e2fff70a1c5337432a12a1aa48..f198c4c08e93277b3a14a32d906b8083a94a8a2c 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -55,28 +55,20 @@ MatchBackwardFilter(HloInstruction* conv) { // v v // Convolution // conv - // | - // v - // Transpose (optional if identity transposition) CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); - // If the forward convolution is followed by a transpose, we can fuse the - // transpose into the backward convolution as well. - HloInstruction* transpose = nullptr; - if (conv->user_count() == 1) { - HloInstruction* single_user = *conv->users().begin(); - if (single_user->opcode() == HloOpcode::kTranspose) { - transpose = single_user; - } - } // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = conv->convolution_dimension_numbers(); auto input_batch_dim = conv_dnums.input_batch_dimension(); auto input_feature_dim = conv_dnums.input_feature_dimension(); + auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); + auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); + auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); + auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); auto output_batch_dim = conv_dnums.output_batch_dimension(); auto output_feature_dim = conv_dnums.output_feature_dimension(); - auto spatial_dims = conv_dnums.spatial_dimensions(); + auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); for (const WindowDimension& window_dim : conv->window().dimensions()) { if (window_dim.stride() != 1) { @@ -97,7 +89,8 @@ MatchBackwardFilter(HloInstruction* conv) { } // Padding high will be checked in Step 3. } - if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) { + if (input_batch_dim == output_batch_dim && + !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " "to fold it to a backward filter convolution."; @@ -108,11 +101,11 @@ MatchBackwardFilter(HloInstruction* conv) { // // Compute the window of the backward convolution. Window backward_conv_window; - for (int i = 0; i < spatial_dims.size(); ++i) { + for (int i = 0; i < input_spatial_dims.size(); ++i) { WindowDimension* dim = backward_conv_window.add_dimensions(); // The window size of the backward convolution equals the output size of the // forward convolution. - int64 filter_size = conv->shape().dimensions(spatial_dims[i]); + int64 filter_size = conv->shape().dimensions(output_spatial_dims[i]); dim->set_size(filter_size); // The window stride equals the window dilation of the forward convolution. dim->set_stride(conv->window().dimensions(i).window_dilation()); @@ -120,7 +113,8 @@ MatchBackwardFilter(HloInstruction* conv) { // activations. dim->set_padding_low(conv->window().dimensions(i).padding_low()); - int64 input_size = conv->operand(0)->shape().dimensions(spatial_dims[i]); + int64 input_size = + conv->operand(0)->shape().dimensions(input_spatial_dims[i]); int64 output_size = conv->window().dimensions(i).size(); // Compute the range of the amount of valid high padding. We first compute // min_padding_high, the amount of padding on the right/bottom to ensure the @@ -167,50 +161,32 @@ MatchBackwardFilter(HloInstruction* conv) { } } - // To make future HLO passes easier, we canonicalize the fused expression by - // adding an identity transposition if it's omitted in the pattern. - if (transpose == nullptr) { - // Create an identity transposition with the same rank as the forward - // convolution. - HloComputation* parent_computation = conv->parent(); - std::vector transpose_dimensions(ShapeUtil::Rank(conv->shape())); - std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0); - transpose = - parent_computation->AddInstruction(HloInstruction::CreateTranspose( - conv->shape(), conv, transpose_dimensions)); - TF_CHECK_OK(conv->ReplaceAllUsesWith(transpose)); - } - // Restore the dimension numbers of the backward convolution from the forward // convolution. The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; backward_conv_dnums.set_input_batch_dimension(input_feature_dim); backward_conv_dnums.set_input_feature_dimension(input_batch_dim); - backward_conv_dnums.set_output_batch_dimension(output_feature_dim); - backward_conv_dnums.set_output_feature_dimension(output_batch_dim); - for (int i = 0; i < spatial_dims.size(); ++i) { - backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]); + for (int i = 0; i < input_spatial_dims.size(); ++i) { + backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); + } + backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); + backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); + for (int i = 0; i < kernel_spatial_dims.size(); ++i) { + backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); } // The dimension numbering of the output of the forward convolution (before // transposition) is the same as that of the activations (according to the // semantics of kConvolution). The batch dimension of the activations should // be treated as the input feature dimension, and the feature dimension should // be treated as the output feature. - // - // The output of the forward convolution needs to be transposed to fit into - // the dimension numbering of the weight gradients. This transposition maps - // dimension i to PositionInContainer(transpose->dimensions(), i). - backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), output_batch_dim)); - backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), output_feature_dim)); - for (int i = 0; i < spatial_dims.size(); ++i) { - backward_conv_dnums.add_kernel_spatial_dimensions( - PositionInContainer(transpose->dimensions(), spatial_dims[i])); + backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); + backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); + for (int i = 0; i < output_spatial_dims.size(); ++i) { + backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector({transpose, conv}), + return std::make_tuple(true, std::vector({conv}), backward_conv_window, backward_conv_dnums); } @@ -272,12 +248,14 @@ MatchBackwardInput(HloInstruction* conv) { } } - const auto& spatial_dims = dnums.spatial_dimensions(); - CHECK_EQ(conv->window().dimensions().size(), spatial_dims.size()); + const auto& input_spatial_dims = dnums.input_spatial_dimensions(); + const auto& output_spatial_dims = dnums.output_spatial_dimensions(); + CHECK_EQ(conv->window().dimensions().size(), input_spatial_dims.size()); + CHECK_EQ(output_spatial_dims.size(), input_spatial_dims.size()); const Window& old_window = conv->window(); Window new_window = old_window; - for (size_t i = 0; i < spatial_dims.size(); ++i) { + for (size_t i = 0; i < input_spatial_dims.size(); ++i) { // Restore backward convolution's padding config from the matched pattern. // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc // for how we convert backward input convolution to a variant of forward @@ -310,8 +288,9 @@ MatchBackwardInput(HloInstruction* conv) { // end at the border. The maximum amount (max_padding_high) equals // min_padding_high+stride-1 -- max_padding_high+1 would cause the output // size to change. - auto unpadded_input_size = conv->shape().dimensions(spatial_dims[i]); - auto output_size = conv->operand(0)->shape().dimensions(spatial_dims[i]); + auto unpadded_input_size = conv->shape().dimensions(output_spatial_dims[i]); + auto output_size = + conv->operand(0)->shape().dimensions(input_spatial_dims[i]); auto padded_input_size = kernel_size + dim->stride() * (output_size - 1); auto total_pad_size = padded_input_size - unpadded_input_size; auto min_padding_high = total_pad_size - backward_padding_low; diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc index 19b122ba0603b4ec08d73e05da4c2ae11a760553..34e6bdb117d47a3d7e1eb3bae5806e130e94ea79 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -46,23 +46,27 @@ class ConvolutionFoldingTest : public HloTestBase { // // TODO(jingyue): Add more tests on NCHW input order which TF also supports. tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); - tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); - tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( 3); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); - tf_default_dnums_for_backward_input_.add_spatial_dimensions(1); - tf_default_dnums_for_backward_input_.add_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_input_.add_input_spatial_dimensions(2); + tf_default_dnums_for_backward_input_.add_output_spatial_dimensions(2); tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2); tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0); @@ -82,7 +86,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) { +TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -132,7 +136,7 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + EXPECT_TRUE(FoldConvolution(module.get())); } // Extracted from block35 training. @@ -151,13 +155,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { conv_window.mutable_dimensions(i)->set_padding_low(1); conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -185,13 +185,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { conv_window.mutable_dimensions(i)->set_padding_high(-1); conv_window.mutable_dimensions(i)->set_window_dilation(2); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -218,13 +214,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { // Uneven padding: padding_low=0, padding_high=1 conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = @@ -258,8 +250,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { conv_dnums.set_output_batch_dimension(0); conv_dnums.set_input_feature_dimension(1); conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_spatial_dimensions(2); - conv_dnums.add_spatial_dimensions(3); + conv_dnums.add_input_spatial_dimensions(2); + conv_dnums.add_output_spatial_dimensions(2); + conv_dnums.add_input_spatial_dimensions(3); + conv_dnums.add_output_spatial_dimensions(3); conv_dnums.set_kernel_input_feature_dimension(0); conv_dnums.set_kernel_output_feature_dimension(1); conv_dnums.add_kernel_spatial_dimensions(2); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 536b96dcf620e908e25a775bc2efb57ba5f5edd6..899cc5c83b99f1bb6154f883ca17871863e1f457 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -28,12 +29,12 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { +using se::dnn::AlgorithmDesc; using se::dnn::BatchDescriptor; using se::dnn::ConvolutionDescriptor; using se::dnn::DataLayout; using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; -using se::dnn::AlgorithmDesc; ConvolveScratchAllocator::ConvolveScratchAllocator( int device_ordinal, DeviceMemoryAllocator* memory_allocator) @@ -130,8 +131,9 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( const int effective_num_dimensions = std::max(2, num_dimensions); CHECK_EQ(F32, output_shape_.element_type()); - CHECK_EQ(num_dimensions, dim_nums_.spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size()); for (const WindowDimension& dim : window_.dimensions()) { CHECK_EQ(dim.padding_low(), dim.padding_high()); } @@ -147,7 +149,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( // Note that the dimensions are reversed. The same holds below. input_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - input_shape_.dimensions(dim_nums_.spatial_dimensions(dim))); + input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim))); } FilterDescriptor filter_descriptor(effective_num_dimensions); @@ -181,7 +183,7 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( for (int dim = 0; dim < num_dimensions; ++dim) { output_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), - output_shape_.dimensions(dim_nums_.spatial_dimensions(dim))); + output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim))); } // Add a singleton dimension in the 1D convolution case. @@ -257,28 +259,52 @@ tensorflow::Status ConvolutionThunk::Convolve( } std::vector ConvolutionThunk::GetAlgorithms( - se::StreamExecutor* stream_exec) const { + bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const { std::vector algorithms; - // TODO(yangzihao): Currently disable the use of winograd nonfused in XLA - // by default. Should send in conv parameters and enable it when - // ShouldIncludeWinogradNonfusedAlgo() returns true. switch (convolution_kind_) { case ConvolutionKind::kBackwardFilter: CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - /*with_winograd_nonfused=*/false, &algorithms)); + with_winograd_nonfused, &algorithms)); break; case ConvolutionKind::kBackwardInput: CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - /*with_winograd_nonfused=*/false, &algorithms)); + with_winograd_nonfused, &algorithms)); break; case ConvolutionKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(/*with_winograd_nonfused=*/false, + CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, &algorithms)); break; } return algorithms; } +static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) { + if (algo.tensor_ops_enabled()) { + return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + } + return tensorflow::strings::StrCat(algo.algo_id()); +} + +// Determines whether we can safely perform a winograd non-fused convolution for +// the given input and output descriptors. This works around b/68264959, an +// integer overflow in cuDNNv5 and cuDNNv6. +static bool ShouldIncludeWinogradNonfusedAlgo( + const BatchDescriptor& input_descriptor, + const BatchDescriptor& output_descriptor) { + int64 batch = input_descriptor.count(); + int64 in_depths = input_descriptor.feature_map_count(); + int64 in_rows = input_descriptor.height(); + int64 in_cols = input_descriptor.width(); + int64 out_depths = output_descriptor.feature_map_count(); + + int64 total_size = 16 * std::ceil(batch / 16.0) * + std::max(in_depths, out_depths) * in_cols * in_rows * + sizeof(float); + int64 threshold = 1L << 31; + + return total_size < threshold; +} + tensorflow::Status ConvolutionThunk::ConvolveWithTune( const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, const FilterDescriptor& filter_descriptor, @@ -288,21 +314,29 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( const ConvolutionDescriptor& convolution_descriptor, const BufferAllocations& buffer_allocations, se::Stream* stream) { // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - if (best_algorithm_.algorithm().is_default()) { + if (!best_algorithm_.has_value()) { + best_algorithm_.emplace(); + // Auto-tuning either is disabled or only happens in the first run of this // function. VLOG(2) << "Profiling for best convolution algorithm used for " "ConvolutionThunk: " << this; + bool with_winograd_nonfused = + ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor); + se::dnn::ProfileResult best_result; se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = GetAlgorithms(stream->parent()); + std::vector algorithms = + GetAlgorithms(with_winograd_nonfused, stream->parent()); for (auto algorithm : algorithms) { ConvolveScratchAllocator scratch_allocator( buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); se::dnn::ProfileResult profile_result; + VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk: " << this; bool launch_ok = Convolve(input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, convolution_descriptor, @@ -310,6 +344,11 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( &scratch_allocator, &profile_result) .ok(); if (launch_ok && profile_result.is_valid()) { + VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk " << this << " succeeded, taking " + << profile_result.elapsed_time_in_ms() + << "ms. (Best result: " << best_result.elapsed_time_in_ms() + << "ms)"; if (profile_result.elapsed_time_in_ms() < best_result.elapsed_time_in_ms()) { best_result = profile_result; @@ -319,39 +358,42 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( best_result_without_scratch.elapsed_time_in_ms()) { best_result_without_scratch = profile_result; } + } else { + VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) + << " for ConvolutionThunk " << this << " failed."; } } if (best_result.is_valid()) { - best_algorithm_.set_algorithm(best_result.algorithm()); + best_algorithm_->set_algorithm(best_result.algorithm()); } else { LOG(ERROR) << "No convolution algorithm works with profiling. Fall back " "to the default algorithm."; - best_algorithm_.set_algorithm(AlgorithmDesc()); + best_algorithm_->set_algorithm(AlgorithmDesc()); } if (best_result_without_scratch.is_valid()) { - best_algorithm_.set_algorithm_no_scratch( + best_algorithm_->set_algorithm_no_scratch( best_result_without_scratch.algorithm()); } else { LOG(ERROR) << "No convolution algorithm without scratch works with " "profiling. Fall back " "to the default algorithm."; - best_algorithm_.set_algorithm_no_scratch(AlgorithmDesc()); + best_algorithm_->set_algorithm_no_scratch(AlgorithmDesc()); } } { VLOG(2) << "Using convolution algorithm (" - << best_algorithm_.algorithm().algo_id() << ", " - << best_algorithm_.algorithm_no_scratch().algo_id() + << AlgorithmToString(best_algorithm_->algorithm()) << ", " + << AlgorithmToString(best_algorithm_->algorithm_no_scratch()) << ") for ConvolutionThunk: " << this; ConvolveScratchAllocator scratch_allocator( buffer_allocations.device_ordinal(), buffer_allocations.memory_allocator()); return Convolve(input_descriptor, input_data, filter_descriptor, filter_data, output_descriptor, output_data, - convolution_descriptor, best_algorithm_, stream, + convolution_descriptor, *best_algorithm_, stream, &scratch_allocator, nullptr); } } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 13432301b2af34ab4bd0864e39ce22366cc1d11d..7c25a2e6450e30292667ecd7de54b50ac2450767 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { @@ -87,6 +88,14 @@ class ConvolutionThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if the next run of ExecuteOnStream will do autotuning. If so, + // we want the GPU to be quiescent during autotuning, so as not to introduce + // noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream*) override { + return !best_algorithm_.has_value(); + } + private: tensorflow::Status ConvolveWithTune( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, @@ -116,13 +125,15 @@ class ConvolutionThunk : public Thunk { // Returns the convolve algorithms that can be used for this ConvolutionThunk. std::vector GetAlgorithms( + bool with_winograd_nonfused, perftools::gputools::StreamExecutor* stream_exec) const; // Fastest cuDNN convolution algorithm for this thunk learned from // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set - // to the default value indicating cuDNN's convolution will choose - // the best algorithm from some heuristics based on its parameters. - perftools::gputools::dnn::AlgorithmConfig best_algorithm_; + // to the default value, indicating cuDNN's convolution will choose the best + // algorithm from some heuristics based on its parameters. + tensorflow::gtl::optional + best_algorithm_; const ConvolutionKind convolution_kind_; diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc deleted file mode 100644 index 3dc85552015be67c20db9099704334c864b44b51..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" - -#include -#include -#include - -#include "tensorflow/compiler/xla/service/copy_insertion.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" - -namespace xla { -namespace gpu { - -StatusOr GpuCopyInsertion::Run(HloModule* module) { - TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module)); - - TF_ASSIGN_OR_RETURN(auto points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // Make sure all operands of a library call are in memory instead of constants - // in IR. The top-level (index {}) of the points-to set of each operand - // indicates the source(s) of the array buffer. If any of these are constant, - // then add a copy to materialize the array. - HloComputation* computation = module->entry_computation(); - for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) { - if (ImplementedAsLibraryCall(*hlo)) { - for (int64 i = 0; i < hlo->operand_count(); ++i) { - HloInstruction* operand = hlo->mutable_operand(i); - const PointsToSet& points_to = - points_to_analysis->GetPointsToSet(operand); - const auto& element = points_to.element(/*index=*/{}); - if (std::any_of(element.begin(), element.end(), - [](const LogicalBuffer* buffer_source) { - return buffer_source->instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - CopyInsertion::FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); - changed = true; - } - } - } - } - - return changed; -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 1b94499bc6ef6d587cdb1fafec48bc4e5b917c51..6bf00cfb8a53723ae9608093480bf2eed10144dd 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -230,6 +230,66 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr GpuElementalIrEmitter::EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + TF_RET_CHECK(primitive_util::IsComplexType(input_type)); + PrimitiveType component_type = + primitive_util::ComplexComponentType(input_type); + switch (op->opcode()) { + case HloOpcode::kPower: { + // (a+bi)^(c+di) = + // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), + // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); + auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto half_c = ir_builder_->CreateFMul(one_half, c); + + TF_ASSIGN_OR_RETURN( + auto aa_p_bb_to_half_c, + EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c}, + {component_type, component_type}, + component_type)); + auto neg_d = ir_builder_->CreateFNeg(d); + TF_ASSIGN_OR_RETURN( + auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a}, + {component_type, component_type}, + component_type)); + auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN( + auto e_to_neg_d_arg_lhs, + EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type}, + component_type)); + auto coeff = + ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN( + auto ln_aa_p_bb, + EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type}, + component_type)); + auto half_d = ir_builder_->CreateFMul(one_half, d); + auto q = + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), + ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN( + auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type}, + component_type)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), + ir_builder_->CreateFMul(coeff, sin_q)); + } + default: + return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value); + } +} + StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { PrimitiveType input_type = op->operand(0)->shape().element_type(); @@ -237,18 +297,12 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( primitive_util::IsComplexType(input_type) ? primitive_util::ComplexComponentType(input_type) : input_type; - auto real = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {0}); - }; - auto imag = [&](llvm::Value* x) { - return ir_builder_->CreateExtractValue(x, {1}); - }; switch (op->opcode()) { case HloOpcode::kLog: { // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) - auto a = real(operand_value); - auto b = imag(operand_value); + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); llvm::Type* llvm_ty = a->getType(); auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), ir_builder_->CreateFMul(b, b)); @@ -261,34 +315,33 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( {component_type, component_type}, component_type)); auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq), - angle); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); } - // TODO(b/65408531): Implement kPower on GPU, where atan2 is available. - // case HloOpcode::kPower: - // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di)) case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto b = imag(operand_value); + auto b = EmitExtractImag(operand_value); TF_ASSIGN_OR_RETURN( - auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)}, - {component_type}, component_type)); + auto exp_a, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, component_type)); - return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); } case HloOpcode::kCos: { // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = real(operand_value); + auto a = EmitExtractReal(operand_value); auto llvm_ty = a->getType(); TF_ASSIGN_OR_RETURN( - auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, - {component_type}, component_type)); + auto exp_b, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, component_type)); @@ -299,7 +352,7 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), @@ -309,11 +362,12 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( case HloOpcode::kSin: { // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = real(operand_value); + auto a = EmitExtractReal(operand_value); auto llvm_ty = a->getType(); TF_ASSIGN_OR_RETURN( - auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, - {component_type}, component_type)); + auto exp_b, + EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, + {component_type}, component_type)); TF_ASSIGN_OR_RETURN( auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, component_type)); @@ -324,13 +378,71 @@ StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return ComposeComplex( + return EmitComposeComplex( op, ir_builder_->CreateFMul( sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); } + case HloOpcode::kTanh: { + /* + tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) + e^(a+bi) = e^a*(cos(b)+sin(b)i) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) + cos(b)=cos(-b), sin(-b)=-sin(b) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) + =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / + (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / + (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) + This is a complex division, so we can multiply by denom_conj/denom_conj + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * + (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + */ + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, + component_type)); + auto exp_neg_a = ir_builder_->CreateFDiv( + llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(exp_a, exp_a), + ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); + auto real_num = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_a_plus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + auto exp_a_minus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = ir_builder_->CreateFMul( + cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, + exp_a_minus_exp_neg_a_sq)); + auto denom = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), + ir_builder_->CreateFDiv(imag_num, denom)); + } default: return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); } diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 3defa1b696d3addc012702e23102bb1fa140170d..6a537d015209bc507af36b13eeb5d69ce58d8fea 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -61,6 +61,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; + StatusOr EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const override; + StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const override; diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 983cb872924f22be0dfad8aa9ad86f233b909c46..8c6a1f51a8a09ef78950dfe7e89994a3fe247f49 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -52,6 +52,15 @@ class GemmThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if we'll perform autotuning if run on the given stream. If + // so, we want the GPU to be quiescent during autotuning, so as not to + // introduce noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* stream) override { + return autotune_results_.count( + stream->parent()->GetDeviceDescription().name()) != 0; + } + private: const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b5331fe4e2ba34443555e9bf46dfc188cbd6548a..1ccfe323c58422c99fab5efa578be2a1e23e3d1b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include +#include #include #include @@ -30,17 +31,18 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" -#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" @@ -62,10 +64,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" @@ -73,6 +77,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/subprocess.h" +#include "tensorflow/core/platform/tracing.h" namespace se = ::perftools::gputools; @@ -85,6 +90,7 @@ namespace gpu { namespace { +using tensorflow::port::Tracing; using tensorflow::strings::StrCat; // Any address of a variable residing in global memory or returned by one of the @@ -94,15 +100,13 @@ using tensorflow::strings::StrCat; // http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses constexpr int64 kMemoryAlignment = 256; -// Returns the directory containing nvvm libdevice files. This function is -// called in GpuCompiler's constructor, so can't return an error. But -// GpuCompiler::Compile will return an error when the wanted libdevice file -// doesn't exist in the folder this function returns. -string GetLibdeviceDir(const HloModuleConfig& config) { +// Returns the directory containing nvvm libdevice files. config_cuda_data_dir +// should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the +// HloModule being compiled. +string GetLibdeviceDir(const string& config_cuda_data_dir) { std::vector potential_libdevice_dirs; - const string datadir = config.debug_options().xla_gpu_cuda_data_dir(); - if (!datadir.empty()) { - potential_libdevice_dirs.push_back(datadir); + if (!config_cuda_data_dir.empty()) { + potential_libdevice_dirs.push_back(config_cuda_data_dir); } potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); @@ -123,7 +127,7 @@ string GetLibdeviceDir(const HloModuleConfig& config) { // Runs optimization passes on the given HLO module. tensorflow::Status OptimizeHloModule( - HloModule* hlo_module, const se::DeviceDescription& device_desc, + HloModule* hlo_module, const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { { HloPassPipeline pipeline("optimization"); @@ -134,7 +138,7 @@ tensorflow::Status OptimizeHloModule( // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); - + pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification"); @@ -151,6 +155,7 @@ tensorflow::Status OptimizeHloModule( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -220,66 +225,94 @@ tensorflow::Status PrepareHloModuleForIrEmitting( // (and sometime after) copy insertion, to avoid dead code from interfering // with the rewrites. pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } -// Invokes the ptxas tool on the given PTX string, and dumps its output. -void DumpPtxasInfo(const string& ptx, int cc_major, int cc_minor) { +// Compiles the given PTX string using ptxas and returns the resulting machine +// code (i.e. a cubin) as a byte array. +StatusOr> CompilePtx(const string& ptx, int cc_major, + int cc_minor) { + Tracing::TraceMe annotation("Compile PTX", /*is_expensive=*/true); const string ptxas_path = - tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas"); - // Do not log PTX stats if ptxas is not found at the given path. - if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) { - LOG(WARNING) - << "Failed to dump PTX stats because ptxas is not found at path \"" - << ptxas_path << "\"."; - return; + tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin", "ptxas"); + VLOG(2) << "Using ptxas at " << ptxas_path; + auto env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); + + // Write ptx into a temporary file. + string ptx_path; + if (!env->LocalTempFilename(&ptx_path)) { + return InternalError("couldn't get temp PTX file name"); } + auto ptx_cleaner = tensorflow::gtl::MakeCleanup([&ptx_path] { + TF_CHECK_OK(tensorflow::Env::Default()->DeleteFile(ptx_path)); + }); - // Write `ptx` into a temporary file. - char tempdir_template[] = "/tmp/ptxXXXXXX"; - char* tempdir_name = mkdtemp(tempdir_template); - CHECK_NOTNULL(tempdir_name); - string ptx_path = tensorflow::io::JoinPath(tempdir_name, "ptx"); - TF_CHECK_OK( - tensorflow::WriteStringToFile(tensorflow::Env::Default(), ptx_path, ptx)); - LOG(INFO) << "ptx file written to: " << ptx_path; + TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_path, ptx)); + VLOG(2) << "ptx written to: " << ptx_path; // Invoke ptxas and collect its output. + string cubin_path; + if (!env->LocalTempFilename(&cubin_path)) { + return InternalError("couldn't get temp CUBIN file name"); + } + auto cubin_cleaner = tensorflow::gtl::MakeCleanup([&cubin_path] { + // CUBIN file may never be created, so the failure to delete it should not + // produce TF error. + tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); + }); tensorflow::SubProcess ptxas_info_dumper; - ptxas_info_dumper.SetProgram(ptxas_path, - {ptxas_path, ptx_path, "-o", "/dev/null", "-v", - StrCat("-arch=sm_", cc_major, cc_minor)}); + std::vector ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path, + StrCat("-arch=sm_", cc_major, cc_minor)}; + if (VLOG_IS_ON(2)) { + ptxas_args.push_back("-v"); + } + ptxas_info_dumper.SetProgram(ptxas_path, ptxas_args); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); if (!ptxas_info_dumper.Start()) { - LOG(ERROR) << "Failed to launch ptxas."; - return; + return InternalError("Failed to launch ptxas"); } string stderr_output; int exit_status = ptxas_info_dumper.Communicate( /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); XLA_LOG_LINES(tensorflow::INFO, stderr_output); if (exit_status != 0) { - LOG(ERROR) << "ptxas exited with non-zero error code " << exit_status - << "."; + return InternalError("ptxas exited with non-zero error code %d", + exit_status); } + + // Read in the result of compilation and return it as a byte vector. + string cubin; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), + cubin_path, &cubin)); + std::vector cubin_vector(cubin.begin(), cubin.end()); + return cubin_vector; } } // namespace GpuCompiler::GpuCompiler() - : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} + : pointer_size_(llvm::DataLayout(kDataLayout) + .getPointerSize(0 /* default address space */)) {} + +StatusOr> GpuCompiler::RunHloPasses( + std::unique_ptr module, se::StreamExecutor* /*stream_exec*/) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); + Tracing::TraceMe annotation("HLO Transforms", module->name(), + /*is_expensive=*/true); + TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction())); + return std::move(module); +} -StatusOr> GpuCompiler::Compile( +StatusOr> GpuCompiler::RunBackend( std::unique_ptr module, se::StreamExecutor* stream_exec) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); + TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), - stream_exec->GetDeviceDescription(), - ShapeSizeBytesFunction())); TF_RETURN_IF_ERROR( PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); @@ -318,7 +351,7 @@ StatusOr> GpuCompiler::Compile( // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, buffer_assignment->ToString()); - + XLA_VLOG_LINES(2, module->ToString()); const string xla_dump_hlo_proto_to = module->config().debug_options().xla_dump_hlo_proto_to(); if (!xla_dump_hlo_proto_to.empty()) { @@ -334,8 +367,11 @@ StatusOr> GpuCompiler::Compile( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, &ir_emitter_context); - TF_RETURN_IF_ERROR( - entry_computation->root_instruction()->Accept(&ir_emitter)); + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); + TF_RETURN_IF_ERROR( + entry_computation->root_instruction()->Accept(&ir_emitter)); + } if (user_pre_optimization_hook_) { TF_CHECK_OK(user_pre_optimization_hook_(llvm_module)); @@ -359,12 +395,21 @@ StatusOr> GpuCompiler::Compile( /*optimized=*/false)); } - // Reserve space for the PTX to be generated for this module. - string* ptx; + string libdevice_dir; { tensorflow::mutex_lock lock(mutex_); - generated_ptxes_.emplace_back(MakeUnique()); - ptx = generated_ptxes_.back().get(); + + // Find the directory containing libdevice. To avoid searching for it every + // time, we have a one-element cache, keyed on the module's config's + // cuda_data_dir. + const auto& config_cuda_data_dir = + module->config().debug_options().xla_gpu_cuda_data_dir(); + if (cached_libdevice_dir_.empty() || + cached_cuda_data_dir_ != config_cuda_data_dir) { + cached_cuda_data_dir_ = config_cuda_data_dir; + cached_libdevice_dir_ = GetLibdeviceDir(config_cuda_data_dir); + } + libdevice_dir = cached_libdevice_dir_; } int cc_major, cc_minor; if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, @@ -374,12 +419,13 @@ StatusOr> GpuCompiler::Compile( cc_major = 2; cc_minor = 0; } - if (libdevice_dir_.empty()) { - // Compute libdevice_dir_ just once and cache it in this member. - libdevice_dir_ = GetLibdeviceDir(module->config()); + + string ptx; + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - CompileToPtx"); + TF_ASSIGN_OR_RETURN(ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, + module->config(), libdevice_dir)); } - TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, - module->config(), libdevice_dir_)); if (!ir_dump_directory.empty()) { TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( @@ -394,20 +440,47 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "LLVM module after optimizations:"; XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module)); VLOG(2) << "PTX:"; - XLA_VLOG_LINES(2, *ptx); - if (VLOG_IS_ON(2)) { - DumpPtxasInfo(*ptx, cc_major, cc_minor); + XLA_VLOG_LINES(2, ptx); + + // Write PTX to IR dump directory, if IR dumping was requested. + if (!ir_dump_directory.empty()) { + const string ptx_outfile = tensorflow::io::JoinPath( + ir_dump_directory, StrCat(module->name(), ".ptx")); + auto status = [&] { + auto* env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); + TF_RETURN_IF_ERROR(tensorflow::WriteStringToFile(env, ptx_outfile, ptx)); + return Status::OK(); + }(); + if (!status.ok()) { + LOG(WARNING) << "Couldn't dump PTX for module " << module->name() + << " to " << ptx_outfile << ": " << status; + } } + const std::vector cubin = + CompilePtxOrGetCachedResult(ptx, cc_major, cc_minor); + auto thunk_schedule = MakeUnique( ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment), hlo_schedule->ThunkLaunchOrder()); VLOG(2) << "Printing the thunk schedule..."; XLA_VLOG_LINES(2, thunk_schedule->ToString()); - auto* gpu_executable = - new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(module), - std::move(buffer_assignment), ShapeSizeBytesFunction()); + std::unique_ptr profile_index_map; + std::unique_ptr profile_printer; + + if (module->config().hlo_profiling_enabled()) { + HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + profile_index_map = MakeUnique(*module); + profile_printer = + CreateHloProfilePrinter(*profile_index_map, cost_analysis); + } + + auto* gpu_executable = new GpuExecutable( + ptx, cubin, {cc_major, cc_minor}, std::move(thunk_schedule), + std::move(module), std::move(buffer_assignment), + std::move(profile_printer), std::move(profile_index_map)); if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); @@ -415,11 +488,75 @@ StatusOr> GpuCompiler::Compile( return std::unique_ptr(gpu_executable); } -StatusOr>> GpuCompiler::Compile( - std::vector> modules, - std::vector> stream_execs) { - return Unimplemented( - "Compilation of multiple HLO modules is not yet supported on GPU."); +std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, + int cc_major, + int cc_minor) { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::CompilePtxOrGetCachedResult"); + Tracing::TraceMe annotation("PTX->CUBIN", /*is_expensive=*/true); + bool inserted; + decltype(compilation_cache_.begin()) iter; + // Pointers into compilation_cache_ where the ptx and (optional) cubin are + // stored. + const string* cache_ptx = nullptr; + CompilationCacheValue* cache_value = nullptr; + + { + tensorflow::mutex_lock lock(mutex_); + std::tie(iter, inserted) = compilation_cache_.emplace( + std::piecewise_construct, + std::forward_as_tuple(ptx, cc_major, cc_minor), + std::forward_as_tuple()); + cache_ptx = &iter->first.ptx; + cache_value = &iter->second; + } + + // Compile the ptx if it wasn't in the cache before we called this function. + // Other threads asking for the same compilation key will block on + // cache_value->mutex_ until compilation is done. + { + tensorflow::mutex_lock lock(cache_value->mutex_); + if (inserted) { + CHECK(!cache_value->compilation_done); + if (!ptx.empty()) { + StatusOr> maybe_cubin = + CompilePtx(*cache_ptx, cc_major, cc_minor); + if (maybe_cubin.ok()) { + cache_value->cubin_data = std::move(maybe_cubin).ValueOrDie(); + VLOG(2) << "Compiled PTX size:" << ptx.size() + << " CUBIN size: " << cache_value->cubin_data.size(); + } else { + bool log_warning = true; + if (maybe_cubin.status().code() == + tensorflow::error::Code::NOT_FOUND) { + // Missing ptxas is expected in some environments where CUDA SDK + // binaries are not available. We don't want to spam logs with + // identical warnings in this case. + + // TODO(zhengxq): we should implement a LOG_FIRST_N and LOG_EVERY_N + // for more general usage. + static std::atomic warning_done(false); + log_warning = !warning_done.exchange(true); + } + if (log_warning) { + LOG(WARNING) + << "Failed to compile ptx to cubin. Will attempt to let " + "GPU driver compile the ptx. " + << maybe_cubin.status(); + } + } + } + cache_value->compilation_done = true; + cache_value->compilation_done_cv_.notify_all(); + } else { + while (!cache_value->compilation_done) { + cache_value->compilation_done_cv_.wait(lock); + } + } + } + + CHECK(cache_value != nullptr); + CHECK(cache_value->compilation_done); + return cache_value->cubin_data; } StatusOr>> diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 58e835e5ee3f77b7b5cb3579514b7501bed2a2a1..18e34340205b6f51497e26c45520799d21c55a46 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -26,6 +26,8 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -40,14 +42,20 @@ class GpuCompiler : public LLVMCompiler { GpuCompiler(); ~GpuCompiler() override {} - StatusOr> Compile( + // Bring in + // StatusOr>> Compile( + // std::vector> modules, + // std::vector> + // stream_execs) + using LLVMCompiler::Compile; + + StatusOr> RunHloPasses( std::unique_ptr module, perftools::gputools::StreamExecutor* stream_exec) override; - StatusOr>> Compile( - std::vector> modules, - std::vector> - stream_execs) override; + StatusOr> RunBackend( + std::unique_ptr module, + perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> module, @@ -71,17 +79,72 @@ class GpuCompiler : public LLVMCompiler { static const char* kDataLayout; private: - // The parent directory of libdevice IR libraries. - string libdevice_dir_; + // The size in bytes of a pointer. Used by ShapeSizeBytesFunction. + const int64 pointer_size_; - // The list of PTX strings generated by this GpuCompiler. We let GpuCompiler - // to own them because they need to be alive across the life span of the - // StreamExecutor (b/24776264). tensorflow::mutex mutex_; - std::vector> generated_ptxes_ GUARDED_BY(mutex_); - // The size in bytes of a pointer. Used by ShapeSizeBytesFunction. - int64 pointer_size_; + // When compiling an HLO module, we need to find a path to the nvvm libdevice + // files. We search in the module's config.debug_options().cuda_data_dir() + // and in tensorflow::LibdeviceRoot(), the latter of which is a constant. + // + // We cache the cuda_data_dir() and the result of our search, so that if the + // next module we have to compile has the same cuda_data_dir(), we can skip + // the search. + string cached_cuda_data_dir_ GUARDED_BY(mutex_); + string cached_libdevice_dir_ GUARDED_BY(mutex_); + + // Tries to compile the given ptx string to cubin. Returns a vector with the + // compiled cubin. If compilation was unsuccessful, returns an empty vector. + std::vector CompilePtxOrGetCachedResult(const string& ptx, + int cc_major, int cc_minor); + + // The compilation_cache_ map is a cache from {ptx string, cc_major, cc_minor} + // -> cubin so we don't recompile the same ptx twice. This is important for + // some interactive workflows. (We also cache at the HLO level, but sometimes + // we can't realize that two modules are the same until we lower to ptx.) + // + // Compilation of distinct PTX happens in parallel. If more than one thread + // attempts to compile the same PTX, the fist thread to obtain + // cache_value_->mutex_ performs the compilation. The rest wait() on + // cache_value_->compilation_done_cv_ until the compilation is done. + // + // If compiling the ptx fails, we return an empty cubin, cross our fingers, + // and leave compilation up to the driver. + struct CompilationCacheKey { + CompilationCacheKey(std::string ptx, int cc_major, int cc_minor) + : ptx(std::move(ptx)), cc_major(cc_major), cc_minor(cc_minor) {} + string ptx; + int cc_major; + int cc_minor; + }; + struct CompilationCacheHash { + size_t operator()(const CompilationCacheKey& key) const { + return tensorflow::Hash64Combine( + tensorflow::Hash64Combine(tensorflow::Hash64(key.ptx), key.cc_major), + key.cc_minor); + } + }; + struct CompilationCacheEq { + size_t operator()(const CompilationCacheKey& a, + const CompilationCacheKey& b) const { + return a.cc_major == b.cc_major && a.cc_minor == b.cc_minor && + a.ptx == b.ptx; + } + }; + struct CompilationCacheValue { + bool compilation_done = false; + std::vector cubin_data; + // mutex and condition variable to serialize compilation completing. + tensorflow::mutex mutex_; + tensorflow::condition_variable compilation_done_cv_; + }; + + // Don't even think about switching this to FlatMap; iterator stability is + // critical here. + std::unordered_map + compilation_cache_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc new file mode 100644 index 0000000000000000000000000000000000000000..33d739b79d3664fec3586bbc924b7fa2e10d3256 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/copy_insertion.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace gpu { + +StatusOr GpuCopyInsertion::FindOrInsertCopy( + HloInstruction* hlo) { + HloInstruction*& copy = inserted_copies_[hlo]; + if (copy == nullptr) { + TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); + } + return copy; +} + +StatusOr GpuCopyInsertion::Run(HloModule* module) { + CopyInsertion generic_copy_insertion; + + TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, + HloDataflowAnalysis::Run(module)); + + // Make sure all operands of a library call are in memory instead of constants + // in IR. + for (HloInstruction* hlo : + module->entry_computation()->MakeInstructionPostOrder()) { + if (ImplementedAsLibraryCall(*hlo)) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + HloInstruction* operand = hlo->mutable_operand(i); + TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + const auto& values = dataflow->GetValueSet(operand).values(); + if (std::any_of(values.begin(), values.end(), + [](const HloValue* value) { + return value->defining_instruction()->opcode() == + HloOpcode::kConstant; + })) { + TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); + changed = true; + } + } + } + } + + // Init values of a while node cannot be constants. Insert copies for any + // constants found at the operand of a while. + tensorflow::gtl::FlatSet copied_constants; + for (HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + for (auto& pair : + dataflow->GetInstructionValueSet(instruction->operand(0))) { + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (value->defining_instruction()->opcode() == + HloOpcode::kConstant && + !ContainsKey(copied_constants, value->defining_instruction())) { + HloInstruction* constant = value->defining_instruction(); + TF_ASSIGN_OR_RETURN(HloInstruction * copy, + FindOrInsertCopy(constant)); + TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); + copied_constants.insert(constant); + changed = true; + } + } + } + } + } + + // The GPU backend needs additional copies added due to deficiencies in + // buffer assignment. + TF_ASSIGN_OR_RETURN(bool buffer_assignment_changed, + CopyInsertion::AddCopiesForBufferAssignment(module)); + + return changed || buffer_assignment_changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h similarity index 56% rename from tensorflow/compiler/xla/service/gpu/copy_insertion.h rename to tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 11077dad2e5506eab4fa84d47ad13a26ed1c035a..4d77f337e6eb20f7d79acc0829fde26bbe443f25 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ -#include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" namespace xla { namespace gpu { @@ -25,12 +25,23 @@ namespace gpu { // Besides the modifications made by the generic xla::CopyInsertion, this // GPU-specific copy insertion also materializes operands of library calls by // inserting kCopy instructions. -class GpuCopyInsertion : public CopyInsertion { +class GpuCopyInsertion : public HloPassInterface { public: + tensorflow::StringPiece name() const override { return "copy-insertion"; } + StatusOr Run(HloModule* module) override; + + protected: + // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making + // duplicate copies. + StatusOr FindOrInsertCopy(HloInstruction* hlo); + + // A map containing all copies inserted to materialize operands of library + // calls. The key is the copied instruction and the value is the copy. + tensorflow::gtl::FlatMap inserted_copies_; }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COPY_INSERTION_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 2c4d5150741d75ec2d1cb7e3d41c07ad24f800b0..21e9fc96f61c4f84490fb4d21748e58272564048 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -88,7 +88,7 @@ class HloExecutionProfiler { if (do_profile_) { stream_->ThenStopTimer(per_op_timer_.get()); stream_->BlockHostUntilDone(); - profile_->AddProfileResult( + profile_->SetCyclesTakenBy( hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); } } @@ -108,16 +108,20 @@ class HloExecutionProfiler { // Implementation note: HLO profiling is always enabled for GPU executables, // since we can use timers around thunks. GpuExecutable::GpuExecutable( - tensorflow::StringPiece ptx, + const string& ptx, const std::vector& cubin, + std::pair compute_capability, std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - HloCostAnalysis::ShapeSizeFunction shape_size_function) - : Executable(std::move(hlo_module)), + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map) + : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + std::move(hlo_profile_index_map)), ptx_(ptx), + cubin_(cubin), + compute_capability_(compute_capability), thunk_schedule_(std::move(thunk_schedule)), - assignment_(std::move(assignment)), - shape_size_function_(std::move(shape_size_function)) {} + assignment_(std::move(assignment)) {} Status GpuExecutable::ExecuteThunks( const ServiceExecutableRunOptions* run_options, @@ -125,6 +129,16 @@ Status GpuExecutable::ExecuteThunks( HloExecutionProfile* hlo_execution_profile) { se::Stream* main_stream = run_options->stream(); + std::pair stream_compute_compatibility; + main_stream->parent()->GetDeviceDescription().cuda_compute_capability( + &stream_compute_compatibility.first, + &stream_compute_compatibility.second); + TF_RET_CHECK(stream_compute_compatibility == compute_capability_) + << "Compute capability mismatch; expected {" << compute_capability_.first + << ", " << compute_capability_.second << "}, but was {" + << stream_compute_compatibility.first << ", " + << stream_compute_compatibility.second << "}"; + bool do_profile = hlo_execution_profile != nullptr; if (do_profile) { LOG(WARNING) << "PROFILING: profiling is enabled"; @@ -153,9 +167,16 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } + // If this thunk requests it, wait for all currently-executing thunks to + // finish. This is useful e.g. if the thunk is about to perform autotuning. + if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { + main_stream->BlockHostUntilDone(); + } + profiler.StartOperation(); VLOG(2) << "Executing the thunk for " - << thunk->hlo_instruction()->ToString(); + << thunk->hlo_instruction()->ToString() << " on stream " + << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); if (thunk_schedule_->Depended(thunk)) { auto finish_event = MakeUnique(main_stream->parent()); @@ -345,9 +366,5 @@ const PointsToSet& GpuExecutable::GetRootPointsToSet() const { module().entry_computation()->root_instruction()); } -std::unique_ptr GpuExecutable::CreateCostAnalysis() const { - return MakeUnique(shape_size_function_); -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index 748a8f521bc5293d58de19ab52f4bdecec6cb1e5..e7307e07c0b5608e31f15597d31d11c50f81c6d5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -47,11 +47,15 @@ namespace gpu { // This is an immutable data type after initialization, and thus thread safe. class GpuExecutable : public Executable { public: - GpuExecutable(tensorflow::StringPiece ptx, + // cubin (i.e. the compiled ptx) may be empty, in which case we leave + // compilation up to the GPU driver. + GpuExecutable(const string& ptx, const std::vector& cubin, + std::pair compute_capability, std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - HloCostAnalysis::ShapeSizeFunction shape_size_function); + std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_index_map); // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -64,6 +68,13 @@ class GpuExecutable : public Executable { // Returns the compiled PTX for the computation. tensorflow::StringPiece ptx() const { return ptx_; } + // Returns the cubin (compiled PTX) stored in this GpuExecutable. May be + // empty, in which case compilation is left up to the GPU driver. + const std::vector& cubin() const { return cubin_; } + + // Both overloads of ExecuteOnStream will fail if the compute capability of + // the stream doesn't match the compute capability passed to this object's + // constructor. StatusOr ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice @@ -85,8 +96,6 @@ class GpuExecutable : public Executable { return Unimplemented("Equality test on GPU executable is not implemented."); } - std::unique_ptr CreateCostAnalysis() const override; - private: // If `block_host_until_done` is false, execution will not block the host // until the kernels have completed. This is used as an optimization for @@ -110,8 +119,17 @@ class GpuExecutable : public Executable { // This string should be modified only before ExecuteOnStream. string ir_module_string_; - // The reference to the compiled PTX for the computation. - const tensorflow::StringPiece ptx_; + // The PTX for the computation. + const string ptx_; + + // The GPU machine code for the computation, targeting GPUs at + // compute_capability_. + // + // May be empty, in which case we leave compilation up to the GPU driver. + const std::vector cubin_; + + // The compute capability of the GPU we're targeting with this GpuExecutable. + std::pair compute_capability_; // The thunks to be invoked by this GpuExecutable. They are generated by the // IrEmitter. @@ -121,9 +139,6 @@ class GpuExecutable : public Executable { // memory for every output/temp buffers. const std::unique_ptr assignment_; - // Function to compute the size of a given Shape, in bytes. - const HloCostAnalysis::ShapeSizeFunction shape_size_function_; - TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc similarity index 93% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 0bbd63fb7bfc657cb7bb1de673253c198f5bd25f..50a249f448e7b4956e7bf6bd603d256eca88f71d 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include @@ -80,9 +80,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( const ConvolutionDimensionNumbers& dimension_numbers = instruction->convolution_dimension_numbers(); std::vector input_layout; - for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0; - --i) { - input_layout.push_back(dimension_numbers.spatial_dimensions(i)); + for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; + i >= 0; --i) { + input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); } input_layout.push_back(dimension_numbers.input_feature_dimension()); input_layout.push_back(dimension_numbers.input_batch_dimension()); @@ -102,9 +102,9 @@ Status GpuLayoutAssignment::AddBackendConstraints( *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); std::vector output_layout; - for (int i = dimension_numbers.spatial_dimensions_size() - 1; i >= 0; - --i) { - output_layout.push_back(dimension_numbers.spatial_dimensions(i)); + for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; + i >= 0; --i) { + output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); } output_layout.push_back(dimension_numbers.output_feature_dimension()); output_layout.push_back(dimension_numbers.output_batch_dimension()); diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h similarity index 86% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.h rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 169041eb85c633cb4f1f679bcea127714828308f..7655a3ebf45f83c0125a4257baae7a7229ebdc6d 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -38,4 +38,4 @@ class GpuLayoutAssignment : public LayoutAssignment { } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc similarity index 97% rename from tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc index ac206b89d329d7e4ac91ee51162c9694f6899d78..f68b23c8ce969372a01ce77840e016d82ca5d2ed 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/computation_layout.h" diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f0f036f7f381db15b84db85d3efeec5d8141884e..4cf49d4a723fd2223564afb86f003901f9712b39 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -44,7 +44,7 @@ GpuTransferManager::GpuTransferManager() : GenericTransferManager( se::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize()) {} + .getPointerSize(0 /* default address space */)) {} Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 163a161353fdb90cee2968269d572b8414855551..c2115c49993ef71c4b6dd584e7e0498807666613 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -166,11 +166,46 @@ void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo, *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value; } +// Determines whether hlo's buffers are never modified within the execution of +// consumer. +static bool BuffersInvariantWithinConsumer( + const HloInstruction& hlo, const HloInstruction& consumer, + const BufferAssignment* buffer_assignment) { + // Check if consumer is inside a fusion node -- if so, "dereference" it until + // we get to a non-fusion node. + const HloInstruction* c = &consumer; + while (c->IsFused()) { + c = c->parent()->FusionInstruction(); + } + + // If, after dereferencing c, we end up with a node that's not inside our + // module's top-level computation (say our node is inside a while loop), we + // give up on marking array as invariant, because this HLO may be run multiple + // times (e.g. multiple while loop iterations, or multiple invocations of a + // reducer's computation). TODO(jlebar): We could relax this constraint if we + // emitted an llvm.invariant.group.barrier at the end of the computation. + return c->parent() == c->GetModule()->entry_computation() && + buffer_assignment->HaveDisjointSlices(&hlo, &consumer); +} + llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, + const HloInstruction& consumer, const ShapeIndex& shape_index) { llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), ShapeUtil::GetSubshape(hlo.shape(), shape_index)); alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); + + // The GPU backend emits one kernel per top-level HLO, and LLVM views + // execution of one kernel as the "whole program" executed on the GPU. + // Therefore if hlo's output buffer is not modified within consumer, and if + // consumer runs hlo only once (so that it doesn't create two different + // outputs), then we can mark ir_array as invariant over the whole program. + if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) { + VLOG(2) << "Marking " << hlo.name() << " as invariant within " + << consumer.name(); + ir_array.MarkInvariantOverWholeProgram(&module_->getContext()); + } + return ir_array; } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index a3120f15bcbfb0f2f0bfbd806e7a4ff05316d5dd..62ae1769a1f2fb3b9acaf35bdf18a793232500b0 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -76,8 +76,15 @@ class HloToIrBindings { return it->second.element(shape_index); } - // Return the underlying IrArray of the output of the given instruction. + // Returns the IrArray which contains the output of hlo. + // + // consumer is the HLO in which this IrArray is used -- we use this to (try + // to) add metadata indicating that the array is invariant within consumer. + // + // To get the buffer into which hlo should write its own output, call + // GetIrArray(hlo, hlo). llvm_ir::IrArray GetIrArray(const HloInstruction& hlo, + const HloInstruction& consumer, const ShapeIndex& shape_index = {}); private: diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 9a4bfd0905bb62c02c70e7f2eea46872c07bca89..1d47ffde4331868cbc8a8afb2d01b11e77a7fab0 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -156,8 +156,10 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { conv_dnums.set_output_batch_dimension(0); conv_dnums.set_input_feature_dimension(1); conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_spatial_dimensions(2); - conv_dnums.add_spatial_dimensions(3); + conv_dnums.add_input_spatial_dimensions(2); + conv_dnums.add_output_spatial_dimensions(2); + conv_dnums.add_input_spatial_dimensions(3); + conv_dnums.add_output_spatial_dimensions(3); conv_dnums.set_kernel_output_feature_dimension(0); conv_dnums.set_kernel_input_feature_dimension(1); conv_dnums.add_kernel_spatial_dimensions(2); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 8fb7a6adda9dc7c36eb9aabcbcdc9d77e6c22c4a..658fd05cd4b63c923d21b4a1de16468c0aeec65d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -100,7 +100,7 @@ bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kConvolution) { const ConvolutionDimensionNumbers& dnums = hlo.convolution_dimension_numbers(); - if (dnums.spatial_dimensions_size() > 3) { + if (dnums.input_spatial_dimensions_size() > 3) { return false; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 57a3f713e35b506ad9d5caab1ced2c7b74f8efcf..f64e93024fe134e585411f555810711763f6fcb5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -68,7 +68,8 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *hlo) + .EmitReadArrayElement(index, &ir_builder_); }; } return EmitTargetElementLoop( @@ -128,16 +129,25 @@ Status IrEmitter::HandleSend(HloInstruction*) { return Unimplemented("Send is not implemented on GPU"); } +Status IrEmitter::HandleSendDone(HloInstruction*) { + return Unimplemented("Send-Done is not implemented on GPU"); +} + Status IrEmitter::HandleRecv(HloInstruction*) { return Unimplemented("Recv is not implemented on GPU"); } +Status IrEmitter::HandleRecvDone(HloInstruction*) { + return Unimplemented("Recv-done is not implemented on GPU"); +} + Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_); + llvm_ir::EmitTuple(GetIrArray(*tuple, *tuple), base_ptrs, &ir_builder_, + module_); return Status::OK(); } @@ -163,7 +173,7 @@ Status IrEmitter::EmitCallToNestedComputation( return Status::OK(); } -bool IrEmitter::MaybeEmitSpecialAtomicOperation( +bool IrEmitter::MaybeEmitDirectAtomicOperation( const HloComputation& computation, llvm::Value* output_address, llvm::Value* source_address) { CHECK_EQ(2, computation.num_parameters()); @@ -223,101 +233,189 @@ bool IrEmitter::MaybeEmitSpecialAtomicOperation( return false; } -Status IrEmitter::EmitAtomicOperationForNestedComputation( - const HloComputation& computation, llvm::Value* output_address, - llvm::Value* source_address) { - if (computation.num_parameters() != 2) { - // TODO(b/30258929): We only accept binary computations so far. - return Unimplemented( - "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); - } - - if (MaybeEmitSpecialAtomicOperation(computation, output_address, - source_address)) { - return Status::OK(); - } +// Implements atomic binary operations using atomic compare-and-swap +// (atomicCAS) as follows: +// 1. Reads the value from the memory pointed to by output_address and +// records it as old_output. +// 2. Uses old_output as one of the source operand to perform the binary +// operation and stores the result in new_output. +// 3. Calls atomicCAS which implements compare-and-swap as an atomic +// operation. In particular, atomicCAS reads the value from the memory +// pointed to by output_address, and compares the value with old_output. If +// the two values equal, new_output is written to the same memory location +// and true is returned to indicate that the atomic operation succeeds. +// Otherwise, the new value read from the memory is returned. In this case, +// the new value is copied to old_output, and steps 2. and 3. are repeated +// until atomicCAS succeeds. +// +// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If +// the element type of the binary operation is 32 bits or 64 bits, the integer +// type of the same size is used for the atomicCAS operation. On the other hand, +// if the element type is smaller than 32 bits, int32 is used for the atomicCAS +// operation. In this case, atomicCAS reads and writes 32 bit values from +// the memory, which is larger than the memory size required by the original +// atomic binary operation. We mask off the last two bits of the output_address +// and use the result as an address to read the 32 bit values from the memory. +// This can avoid out of bound memory accesses if tensor buffers are 4 byte +// aligned and have a size of 4N, an assumption that the runtime can guarantee. +// +// The pseudo code is shown below. Variables *_address are pointers to a memory +// region with a size equal to the size of the atomicCAS operation, with the +// exception that new_output_address is a pointer to a memory region with a size +// equal to the element size of the binary operation. +// +// element_size = sizeof(element_type); +// atomic_size = max(32, element_size); +// cas_new_output_address = alloca(atomic_size); +// cas_old_output_address = alloca(atomic_size); +// if (atomic_size != element_size) { +// atomic_address = output_address & ((int64)(-2)); +// new_output_address = cas_new_output_address + (output_address & 3); +// } else { +// atomic_address = output_address; +// new_output_address = cas_new_output_address; +// } +// +// *cas_old_output_address = *atomic_address; +// do { +// *cas_new_output_address = *cas_old_output_address; +// *new_output_address = operation(*new_output_address, *source_address); +// (*cas_old_output_address, success) = +// atomicCAS(atomic_address, *cas_old_output_address, +// *cas_new_output_address); +// } while (!success); +// +Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address) { + llvm::PointerType* output_address_type = + llvm::dyn_cast(output_address->getType()); + CHECK_NE(output_address_type, nullptr); + + // element_type is the data type for the binary operation. + llvm::Type* element_type = output_address_type->getPointerElementType(); + int element_size = llvm_ir::GetSizeInBits(element_type); + llvm::Type* element_address_type = element_type->getPointerTo(); + + int atomic_size = (element_size < 32) ? 32 : element_size; + llvm::Type* atomic_type = ir_builder_.getIntNTy(atomic_size); + llvm::Type* atomic_address_type = + atomic_type->getPointerTo(output_address_type->getPointerAddressSpace()); + + // cas_old_output_address and cas_new_output_address point to the scratch + // memory where we store the old and new values for the repeated atomicCAS + // operations. + llvm::Value* cas_old_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); - // Other binary computations can be made atomic as following (labels are basic - // block names used in the IR emitting code later). - // - // atomic_op_loop_preheader: - // ... - // source = *source_address; - // old_output = *output_address; - // do { - // atomic_op_loop_body_entry: - // new_output = computation(old_output, source); - // (old_output, success) = - // atomicCAS(output_address, old_output, new_output); - // } while (!success); - // - // atomic_op_loop_exit: - // ... - // - // TODO(jingyue): Consider encapsulate the logic of emitting control flow to - // something similar to llvm_ir::ForLoop. - // // Emit preparation code to the preheader. llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock(); - llvm::Type* element_ir_type = - output_address->getType()->getPointerElementType(); - // old_output = *output_address; - llvm::Value* old_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "old_output_location"); - ir_builder_.CreateStore(ir_builder_.CreateLoad(output_address, "old_output"), - old_output_location); + + llvm::Value* atomic_memory_address; + // binop_output_address points to the scratch memory that stores the + // result of the binary operation. + llvm::Value* binop_output_address; + if (element_size < 32) { + // Assume the element size is an integer number of bytes. + CHECK_EQ((element_size % sizeof(char)), 0); + llvm::Type* address_int_type = + module_->getDataLayout().getIntPtrType(output_address_type); + atomic_memory_address = + ir_builder_.CreatePtrToInt(output_address, address_int_type); + llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); + llvm::Value* offset = ir_builder_.CreateAnd(atomic_memory_address, mask); + mask = llvm::ConstantInt::get(address_int_type, -2); + atomic_memory_address = ir_builder_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = + ir_builder_.CreateIntToPtr(atomic_memory_address, atomic_address_type); + binop_output_address = ir_builder_.CreateAdd( + ir_builder_.CreatePtrToInt(cas_new_output_address, address_int_type), + offset); + binop_output_address = + ir_builder_.CreateIntToPtr(binop_output_address, element_address_type); + } else { + atomic_memory_address = + ir_builder_.CreateBitCast(output_address, atomic_address_type); + binop_output_address = + ir_builder_.CreateBitCast(cas_new_output_address, element_address_type); + } + + // Use the value from the memory that atomicCAS operates on to initialize + // cas_old_output. + llvm::Value* cas_old_output = + ir_builder_.CreateLoad(atomic_memory_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_old_output_address); + llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( ir_builder_.GetInsertPoint(), "atomic_op_loop_exit"); - - // Emit the body of the loop that repeatedly invokes atomicCAS. llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body", ir_builder_.GetInsertBlock()->getParent()); ir_builder_.SetInsertPoint(loop_body_bb); // Change preheader's successor from loop_exit_bb to loop_body_bb. loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb); - // new_output = computation(old_output, source); - llvm::Value* new_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "new_output_location"); + + // Emit the body of the loop that repeatedly invokes atomicCAS. + // + // Use cas_old_output to initialize cas_new_output. + cas_old_output = + ir_builder_.CreateLoad(cas_old_output_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_new_output_address); + // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - computation, {old_output_location, source_address}, new_output_location)); - - // (old_output, success) = atomicCAS(output_address, old_output, new_output); - llvm::Type* element_int_ir_type = - ir_builder_.getIntNTy(element_ir_type->getScalarSizeInBits()); - // cmpxchg accetps integer only, so we bitcast the operands (old_output and - // new_output) to integers of the same bit width, and bitcast the result - // back to the original element type. - llvm::Value* old_output = - ir_builder_.CreateLoad(old_output_location, "old_output"); - llvm::Value* new_output = - ir_builder_.CreateLoad(new_output_location, "new_output"); + computation, {binop_output_address, source_address}, + binop_output_address)); + + llvm::Value* cas_new_output = + ir_builder_.CreateLoad(cas_new_output_address, "cas_new_output"); + + // Emit code to perform the atomicCAS operation + // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, + // cas_new_output); llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( - ir_builder_.CreateBitCast(output_address, - element_int_ir_type->getPointerTo()), - ir_builder_.CreateBitCast(old_output, element_int_ir_type), - ir_builder_.CreateBitCast(new_output, element_int_ir_type), + atomic_memory_address, cas_old_output, cas_new_output, llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); - // cmpxchg returns a pair. The first element is the original value at - // output_address and the second element is whether the swap is successful. + + // Extract the memory value returned from atomicCAS and store it as + // cas_old_output. ir_builder_.CreateStore( - ir_builder_.CreateBitCast( - ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), - element_ir_type), - old_output_location); + ir_builder_.CreateExtractValue(ret_value, 0, "cas_old_output"), + cas_old_output_address); + // Extract the success bit returned from atomicCAS and generate a + // conditional branch on the success bit. ir_builder_.CreateCondBr( ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); - // Restore the insertion point to the exit basic block so that the caller of + // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. SetToFirstInsertPoint(loop_exit_bb, &ir_builder_); return Status::OK(); } +Status IrEmitter::EmitAtomicOperationForNestedComputation( + const HloComputation& computation, llvm::Value* output_address, + llvm::Value* source_address) { + if (computation.num_parameters() != 2) { + // TODO(b/30258929): We only accept binary computations so far. + return Unimplemented( + "We only support atomic functions with exactly two parameters, but " + "computation %s has %lld.", + computation.name().c_str(), computation.num_parameters()); + } + + if (MaybeEmitDirectAtomicOperation(computation, output_address, + source_address)) { + return Status::OK(); + } + + return EmitAtomicOperationUsingCAS(computation, output_address, + source_address); +} + Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); auto on_true = select->operand(1); @@ -325,7 +423,8 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { - llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred), + llvm_ir::EmitTupleSelect(GetIrArray(*select, *select), + GetIrArray(*pred, *select), GetBasePointer(*on_true), GetBasePointer(*on_false), &ir_builder_, module_); return Status::OK(); @@ -340,9 +439,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select) { Status IrEmitter::HandleDot(HloInstruction* dot) { auto lhs_instruction = dot->operand(0); auto rhs_instruction = dot->operand(1); - const llvm_ir::IrArray& target_array = GetIrArray(*dot); - const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction); - const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction); + const llvm_ir::IrArray& target_array = GetIrArray(*dot, *dot); + const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction, *dot); + const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction, *dot); const Shape& lhs_shape = lhs_instruction->shape(); const Shape& rhs_shape = rhs_instruction->shape(); @@ -562,7 +661,8 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce) { // Apply the reduction function to the loaded value. llvm::Value* input_address = - GetIrArray(*arg).EmitArrayElementAddress(input_index, &ir_builder_); + GetIrArray(*arg, *reduce) + .EmitArrayElementAddress(input_index, &ir_builder_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *function, {accumulator_addr, input_address}, accumulator_addr)); @@ -578,7 +678,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()); @@ -613,7 +713,8 @@ Status IrEmitter::HandleRng(HloInstruction* random) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : random->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArray(*operand, *random) + .EmitReadArrayElement(index, &ir_builder_); }; } // Emits a single-threaded loop because the loop body generated by the element @@ -622,10 +723,41 @@ Status IrEmitter::HandleRng(HloInstruction* random) { GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()) .MakeElementGenerator(random, operand_to_generator), - GetIrArray(*random), &ir_builder_) + GetIrArray(*random, *random), &ir_builder_) .EmitLoop(IrName(random)); } +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + + llvm::Value* conditional_result = GetBasePointer(*conditional); + + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + GetBasePointer(*pred), + llvm_ir::AsStringRef(IrName(conditional, "load_predicate_value"))); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + llvm_ir::AsStringRef(IrName(conditional, "boolean_predicate"))); + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + pred_cond, IrName(conditional, "if_then_else"), &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *conditional->true_computation(), {GetBasePointer(*true_arg)}, + conditional_result)); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *conditional->false_computation(), {GetBasePointer(*false_arg)}, + conditional_result)); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 263992d92544166c0d08a6c60b43e78f10f06aed..08bbbe36c72872ba68104c8f328c2f602eb30fa8 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -84,7 +84,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleOutfeed(HloInstruction* outfeed) override; Status HandleSort(HloInstruction* sort) override; Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; @@ -93,6 +95,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleRng(HloInstruction* random) override; + Status HandleConditional(HloInstruction* conditional) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -103,10 +106,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { explicit IrEmitter(const HloModuleConfig& hlo_module_config, IrEmitterContext* ir_emitter_context, bool is_nested); - // A convenient helper for calling HloToIrBindings::GetIrArray. + // Helper for calling HloToIrBindings::GetIrArray. + // + // Gets the IrArray which contains inst. This array has metadata that makes + // it valid only within the IR that implements consumer. If you are + // implementing an HLO and want to get its own output buffer, call + // GetIrArray(hlo, hlo). llvm_ir::IrArray GetIrArray(const HloInstruction& inst, + const HloInstruction& consumer, const ShapeIndex& shape_index = {}) { - return bindings_.GetIrArray(inst, shape_index); + return bindings_.GetIrArray(inst, consumer, shape_index); } // A convenient helper for calling HloToIrBindings::GetBasePointer. llvm::Value* GetBasePointer(const HloInstruction& inst) const { @@ -177,9 +186,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { // be simply implemented using an LLVM atomic instruction. If "computation" is // one of this kind, emits code to do that and returns true; otherwise, // returns false. - bool MaybeEmitSpecialAtomicOperation(const HloComputation& computation, - llvm::Value* output_address, - llvm::Value* source_address); + bool MaybeEmitDirectAtomicOperation(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); + + // A helper method for EmitAtomicOperationForNestedComputation. It implements + // binary atomic operations using atomicCAS with special handling to support + // small data types. + Status EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); StatusOr ComputeNestedElement( const HloComputation& computation, @@ -219,6 +235,7 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleDot(HloInstruction* dot) override; Status HandleFusion(HloInstruction* fusion) override; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5da1a130d5654b86803396b07a6501c59a182c67..5225ff36ff3a8a1b049479c34aa301de8724f73e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -115,7 +115,8 @@ Status IrEmitterNested::HandleParameter(HloInstruction* parameter) { Status IrEmitterNested::EmitTargetElementLoop( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator) { - return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo), &ir_builder_) + return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), + &ir_builder_) .EmitLoop(); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 7b4662fc80c5518135c827489a3724e477b2bad1..8dbc90ee1fb5678f070bdc8999ffa8980197188f 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -123,10 +123,12 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), launch_dims.threads_per_block()); + // Our launch bounds are exact, so we can specify them as reqntidx rather than + // maxntidx. nvvm_annotations_node->addOperand(llvm::MDNode::get( llvm_context, {llvm::ConstantAsMetadata::get(ir_kernel), - llvm::MDString::get(llvm_context, "maxntidx"), + llvm::MDString::get(llvm_context, "reqntidx"), llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } } // namespace @@ -246,6 +248,11 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); @@ -254,6 +261,11 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { return IrEmitter::HandleDot(dot); } +Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { + thunk_sequence_->push_back(BuildKernelThunk(conditional)); + return IrEmitter::HandleConditional(conditional); +} + Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { if (ImplementedAsDnnConvolution(*convolution)) { thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); @@ -282,7 +294,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { MakeUnique(std::move(thunks), fusion)); std::vector parameter_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + parameter_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter( hlo_module_config_, ir_emitter_context_->llvm_module(), @@ -344,7 +356,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { - operand_arrays.push_back(GetIrArray(*operand)); + operand_arrays.push_back(GetIrArray(*operand, *fusion)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), @@ -355,7 +367,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { // Array to write into. Because this is an in-place operation, this is the // same as operand 0's array. - llvm_ir::IrArray output_array = GetIrArray(*fusion); + llvm_ir::IrArray output_array = GetIrArray(*fusion, *fusion); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); @@ -693,9 +705,10 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { constexpr int64 tile_size = 32; constexpr int64 num_rows = 8; int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*(copy->operand(0))) + GetIrArray(*copy->operand(0), *copy) .CastToShape(reduced_input_shape, &ir_builder_), - GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), + GetIrArray(*copy, *copy) + .CastToShape(reduced_output_shape, &ir_builder_), tile_size, num_rows, &ir_builder_); UpdateLaunchDimensions(LaunchDimensions(num_tiles, num_rows * tile_size), LastThunk(), ir_emitter_context_->llvm_module()); @@ -850,9 +863,11 @@ Status IrEmitterUnnested::EmitColumnReduction( &ir_builder_); const HloInstruction* output = reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; - llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( - llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_, - "output_element_address"); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); return EmitAtomicOperationForNestedComputation( *reducer, output_address, partial_reduction_result_address); }; @@ -1081,16 +1096,25 @@ Status IrEmitterUnnested::EmitRowReduction( // from the warp. llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, &ir_builder_); + int bit_width = llvm_ir::GetSizeInBits(element_ir_type); + // bitcast cannot be applied to aggregate types (even packed ones), so we + // instead bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() + ? ir_builder_.getIntNTy(bit_width) + : element_ir_type; for (int shuffle_distance = 16; shuffle_distance >= 1; shuffle_distance /= 2) { llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( - partial_reduction_result_address, "partial_reduction_result"); + ir_builder_.CreateBitCast(partial_reduction_result_address, + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( element_ir_type, nullptr, "result_from_other_lane"); ir_builder_.CreateStore( EmitShuffleDown(partial_reduction_result, ir_builder_.getInt32(shuffle_distance), &ir_builder_), - result_from_other_lane); + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *reducer, {partial_reduction_result_address, result_from_other_lane}, partial_reduction_result_address)); @@ -1107,9 +1131,11 @@ Status IrEmitterUnnested::EmitRowReduction( "lane_id_is_zero", &ir_builder_); llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &ir_builder_); - llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress( - llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), &ir_builder_, - "output_element_address"); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); return EmitAtomicOperationForNestedComputation( *reducer, output_address, partial_reduction_result_address); }; @@ -1249,11 +1275,12 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { MakeUnique(std::move(thunks), reduce)); return EmitReductionToVector( reduce, input->shape(), - [this, input](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*input).EmitReadArrayElement(index, &ir_builder_); + [&](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*input, *reduce) + .EmitReadArrayElement(index, &ir_builder_); }, - [this, init_value](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value) + [&](const llvm_ir::IrArray::Index& index) { + return GetIrArray(*init_value, *reduce) .EmitReadArrayElement(index, &ir_builder_); }, dimensions_to_reduce, reducer); @@ -1417,7 +1444,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array(GetIrArray(*operand)); + llvm_ir::IrArray operand_array = GetIrArray(*operand, *select_and_scatter); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); ir_builder_.CreateStore(operand_data, selected_value_address); @@ -1470,9 +1497,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter( ir_builder_.CreateLoad(selected_index_address_slot)); } llvm::Value* source_value_address = - GetIrArray(*source).EmitArrayElementAddress(source_index, &ir_builder_); + GetIrArray(*source, *select_and_scatter) + .EmitArrayElementAddress(source_index, &ir_builder_); llvm::Value* output_value_address = - GetIrArray(*select_and_scatter) + GetIrArray(*select_and_scatter, *select_and_scatter) .EmitArrayElementAddress(selected_index, &ir_builder_); return EmitAtomicOperationForNestedComputation( *select_and_scatter->scatter(), output_value_address, @@ -1749,7 +1777,7 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, return EmitTargetElementLoopInThunk( *hlo, [=](const llvm_ir::IrArray::Index& index) { - return GetIrArray(*init_value) + return GetIrArray(*init_value, *hlo) .EmitReadArrayElement(index, &ir_builder_); }, thunk); @@ -1850,7 +1878,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( UpdateLaunchDimensions(launch_dimensions, thunk, ir_emitter_context_->llvm_module()); if (!hlo.IsMultiOutputFusion()) { - return ParallelLoopEmitter(element_generator, GetIrArray(hlo), + return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo), launch_dimensions, &ir_builder_) .EmitLoop(IrName(&hlo)); } @@ -1858,7 +1886,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( // For multiple outputs fusion, we need to emit each operand and the root. std::vector output_arrays; for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) { - output_arrays.push_back(GetIrArray(hlo, {i})); + output_arrays.push_back(GetIrArray(hlo, hlo, {i})); } TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions, &ir_builder_) @@ -1869,7 +1897,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_, + llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &ir_builder_, module_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 69399e36c4c4faa7c6ed5c79a3f094490f022001..96606993696354f36e143b3b994bbe6afb902df3 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -48,6 +48,12 @@ tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { // StreamExecutor uses the latter. loader_spec_->AddCudaPtxInMemory( se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_); + + if (!executable.cubin().empty()) { + loader_spec_->AddCudaCubinInMemory( + reinterpret_cast(executable.cubin().data()), kernel_name_); + } + return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 81cca312982a3a5ee98b3914447f2d878354c3a5..059943d48cd34b0ac487b91c3f3079ee3f761229 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -34,7 +34,7 @@ limitations under the License. #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/Bitcode/BitcodeWriter.h" -#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/CodeGen/CommandFlags.def" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" @@ -60,6 +60,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/tracing.h" namespace xla { namespace gpu { @@ -76,7 +77,7 @@ static string GetLibdeviceFilename(const string& libdevice_dir_path, // Since CUDA 9.0, all GPU versions are included in a single file const char* unified_libdevice_filename = "libdevice.10.bc"; std::vector unified_libdevice_files; - const tensorflow::Status status = + const tensorflow::Status status = tensorflow::Env::Default()->GetMatchingPaths( tensorflow::io::JoinPath(libdevice_dir_path, unified_libdevice_filename), &unified_libdevice_files); @@ -342,6 +343,13 @@ StatusOr CompileModuleToPtx(llvm::Module* module, std::pair compute_capability, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path) { + // If the module has no functions or globals, there's nothing to compile. Just + // return an empty string. + if (module->empty() && module->global_empty()) { + VLOG(2) << "Module '" << llvm_ir::AsString(module->getName()) + << "' is empty. Skipping compilation."; + return string(); + } // Link the input module with libdevice, to pull in implementations of some // builtins. TF_RETURN_IF_ERROR( @@ -481,9 +489,11 @@ StatusOr CompileToPtx(llvm::Module* module, string ptx; { - ScopedLoggingTimer compilation_timer( - "Compile module " + llvm_ir::AsString(module->getName()), - /*vlog_level=*/2); + tensorflow::port::Tracing::TraceMe annotation( + "Compiling IR", llvm_ir::AsString(module->getName()), + /*is_expensive=*/true); + XLA_SCOPED_LOGGING_TIMER("Compile module " + + llvm_ir::AsString(module->getName())); TF_ASSIGN_OR_RETURN( ptx, CompileModuleToPtx(module, compute_capability, hlo_module_config, libdevice_dir_path)); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 9274e16a455fc1a958cee5101b6a9ef7ce619347..c29fee0879c02021fdc23ac0e02ab398cf40f99e 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -49,8 +49,8 @@ HloInstruction* MaybePaddedAndSlicedInput( // applies positive padding and dilation. PaddingConfig padding_config = MakeNoPaddingConfig(input->shape().dimensions_size()); - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.input_spatial_dimensions(i); padding_config.mutable_dimensions(dim)->set_edge_padding_low( std::max(0LL, conv_window.dimensions(i).padding_low())); padding_config.mutable_dimensions(dim)->set_edge_padding_high( @@ -81,8 +81,8 @@ HloInstruction* MaybePaddedAndSlicedInput( std::vector limit_indices(input->shape().dimensions().begin(), input->shape().dimensions().end()); std::vector strides(input->shape().dimensions_size(), 1); - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.input_spatial_dimensions(i); // If dimension "dim" has negative padding, increase the start index or // decrement the limit index by the amount of negative padding. start_indices[dim] += @@ -117,8 +117,8 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) { padding_config.add_dimensions(); } - for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) { - int64 dim = conv_dnums.spatial_dimensions(i); + for (size_t i = 0; i < conv_dnums.kernel_spatial_dimensions().size(); ++i) { + int64 dim = conv_dnums.kernel_spatial_dimensions(i); padding_config.mutable_dimensions(dim)->set_interior_padding( conv_window.dimensions(i).window_dilation() - 1); } @@ -202,8 +202,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose the lesser of padding_low and padding_high as the new padding. - HloInstruction* transpose = backward_conv->fused_expression_root(); - HloInstruction* forward_conv = transpose->mutable_operand(0); + HloInstruction* forward_conv = backward_conv->fused_expression_root(); HloInstruction* input = backward_conv->mutable_operand(0); Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); @@ -229,7 +228,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // later. Therefore, the amount of new padding (low or high) is the minimum // of the amount of old padding low and old padding high. int64 new_conv_padding = std::min(padding_low, padding_high); - int64 dim = backward_conv_dnums.spatial_dimensions(i); + int64 dim = backward_conv_dnums.input_spatial_dimensions(i); input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( padding_low - new_conv_padding); input_padding_config.mutable_dimensions(dim)->set_edge_padding_high( @@ -269,19 +268,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), padded_input, output, new_forward_conv_window, forward_conv_dnums)); - HloInstruction* new_transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - ShapeInference::InferTransposeShape(new_forward_conv->shape(), - transpose->dimensions()) - .ConsumeValueOrDie(), - new_forward_conv, transpose->dimensions())); - - // Fuse the new forward convolution and the new transpose to the new backward - // convolution. + // Fuse the new forward convolution to the new backward convolution. HloInstruction* new_backward_conv = computation->CreateFusionInstructionForBackwardConvolution( - {new_transpose, new_forward_conv}, - HloInstruction::FusionKind::kConvBackwardFilter, + {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter, new_backward_conv_window, backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; @@ -369,12 +359,11 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( std::vector limit_indices( new_backward_conv->shape().dimensions().begin(), new_backward_conv->shape().dimensions().end()); - std::vector strides(new_backward_conv->shape().dimensions_size(), - 1LL); + std::vector strides(new_backward_conv->shape().dimensions_size(), 1LL); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); - int64 dim = backward_conv_dnums.spatial_dimensions(i); + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); if (padding_low > padding_high) { // If the amount of low padding (of the old backward convolution) is // larger, we internally pad the low end of the activations and slice diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d0d2deee24848184278e3e51dcaa3bb673b5fadc..6cf280df05496716a0780d61ded92efd9982734c 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -44,37 +44,41 @@ std::ostream& operator<<(std::ostream& out, // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, const se::DeviceDescription& device_desc, - PartitionStrategy partition_strategy) { - int64 warp_size = device_desc.threads_per_warp(); - + const Shape& shape, const se::DeviceDescription& device_desc) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); } - // Calculate the number of threads per block. - // Initialize threads_per_block as the threads-per-block limit. - int64 threads_per_block = device_desc.threads_per_block_limit(); - VLOG(2) << "Initial # of threads per block = " << threads_per_block; - - if (partition_strategy == PartitionStrategy::kLatency) { - // Limit the thread count to allow maximum number of registers per thread. - // TODO(b/28560520): We don't have to assume the emitted kernel will use up - // all the registers. We could use ptxas to examine the actual number of - // register used, and set the thread count accordingly. - int64 threads_per_block_limit_due_to_registers = - device_desc.registers_per_core_limit() / - device_desc.registers_per_thread_limit(); - CHECK_NE(0, threads_per_block_limit_due_to_registers); - if (threads_per_block_limit_due_to_registers < threads_per_block) { - threads_per_block = - // Make `threads_per_block` a multiple of warp size to use GPU - // efficiently. - warp_size * - std::max(1LL, threads_per_block_limit_due_to_registers / warp_size); - VLOG(2) << "Update # of threads per block due to register pressure = " - << threads_per_block; + // Since we don't do any inter-warp communication, we're free to choose any + // block size we want, subject to hardware constraints. We choose the + // smallest block size that allows the GPU to reach full occupancy (assuming + // the kernel uses sufficiently few registers). This gives us max performance + // when the kernel uses few registers, and lets us scale down gracefully as + // the kernel uses more registers. + // + // Specifically, we choose the number of threads per block such that + // + // * = + + auto threads_per_core = device_desc.threads_per_core_limit(); + auto blocks_per_core = device_desc.blocks_per_core_limit(); + int64 threads_per_block; + if (threads_per_core != 0 && blocks_per_core != 0) { + threads_per_block = device_desc.threads_per_core_limit() / + device_desc.blocks_per_core_limit(); + } else { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; } } @@ -84,8 +88,6 @@ LaunchDimensions CalculateLaunchDimensions( << threads_per_block << ") because the latter is smaller."; } - // Calculate the block count. We copy the strategy used by Eigen: - // eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h int64 block_count = CeilOfRatio(num_elements, threads_per_block); VLOG(2) << tensorflow::strings::Printf( "Initialized the block count to ceil(# of elements / threads per " diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 8f7fce884acc93fd39510ad0826b819a6d9731a7..0bf463a6ef95d5a32784838c08ad239752fd1acf 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -30,14 +30,6 @@ limitations under the License. namespace xla { namespace gpu { -enum class PartitionStrategy { - // Optimized for latency by allowing maximum number of registers per thread. - kLatency, - // Optimized for throughput. This may limit registers per thread and cause - // longer latency. - kThroughput -}; - // Encapsulates the launch dimensions of a kernel, e.g., the block count and the // number of threads per block. class LaunchDimensions { @@ -66,8 +58,7 @@ std::ostream& operator<<(std::ostream& out, LaunchDimensions CalculateLaunchDimensions( const Shape& shape, - const perftools::gputools::DeviceDescription& device_desc, - PartitionStrategy partition_strategy = PartitionStrategy::kLatency); + const perftools::gputools::DeviceDescription& device_desc); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 0ff27888ad72f8190400c22a9086d1965448662c..486ea7d7e1dad3f7f37d50565e176fbf567f5cc4 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -70,6 +70,19 @@ class Thunk { return tensorflow::Status::OK(); } + // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) + // before calling ExecuteOnStream(stream). If it returns true, it's the + // user's responsibility to wait for all activity on the GPU to finish before + // calling ExecuteOnStream. + // + // This value is not required to be constant for a given Thunk. For example, + // a Thunk that performs autotuning may return true for its first run and + // false thereafter. + virtual bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* /*stream*/) { + return false; + } + // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index 44188473d39088923c67216facab472a4e4ee09f..f16daa0b5481474e754c880ead1945297ca50168 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -17,9 +17,12 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" namespace xla { namespace { @@ -33,8 +36,6 @@ class WhileTransformerTest : public HloTestBase { : module_(CreateNewModule()), induction_variable_shape_(ShapeUtil::MakeShape(S32, {})), data_shape_(ShapeUtil::MakeShape(F32, {8})), - loop_state_shape_(ShapeUtil::MakeTupleShape( - {induction_variable_shape_, data_shape_})), condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {} std::unique_ptr BuildConditionComputation( @@ -42,8 +43,8 @@ class WhileTransformerTest : public HloTestBase { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(limit))); - auto loop_state = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( + 0, GetLoopStateShape(tuple_index), "loop_state")); auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( limit_const->shape(), loop_state, tuple_index)); @@ -58,8 +59,8 @@ class WhileTransformerTest : public HloTestBase { const int64 increment) { auto builder = HloComputation::Builder(TestName() + ".Body"); // Create param instruction to access loop state. - auto loop_state = builder.AddInstruction( - HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); + auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter( + 0, GetLoopStateShape(ind_var_tuple_index), "loop_state")); // Update the induction variable GTE(ind_var_tuple_index). auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( @@ -73,7 +74,7 @@ class WhileTransformerTest : public HloTestBase { data_shape_, loop_state, data_tuple_index)); // Use 'induction_variable' in computation with no path to output tuple. auto update = builder.AddInstruction( - HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8})); + HloInstruction::CreateBroadcast(data_shape_, induction_variable, {})); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); // Create output Tuple. @@ -98,8 +99,9 @@ class WhileTransformerTest : public HloTestBase { HloInstruction::CreateTuple({induction_var_init, data_init})) : builder.AddInstruction( HloInstruction::CreateTuple({data_init, induction_var_init})); - auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile( - loop_state_shape_, condition, body, loop_state_init)); + auto while_hlo = builder.AddInstruction( + HloInstruction::CreateWhile(GetLoopStateShape(ind_var_tuple_index), + condition, body, loop_state_init)); module_->AddEntryComputation(builder.Build()); return while_hlo; } @@ -115,18 +117,34 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { + HloVerifier verifier([](const Shape& shape) { + return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + }); + TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; - EXPECT_IS_OK(copy_insertion.Run(module_.get()).status()); + TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); + } + + Shape GetLoopStateShape(const int64 ind_var_tuple_index) { + if (ind_var_tuple_index == 0) { + return ShapeUtil::MakeTupleShape( + {induction_variable_shape_, data_shape_}); + } else { + return ShapeUtil::MakeTupleShape( + {data_shape_, induction_variable_shape_}); + } } std::unique_ptr module_; Shape induction_variable_shape_; Shape data_shape_; - Shape loop_state_shape_; Shape condition_result_shape_; }; -TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement0) { // Build computation with induction variable at tuple element 0. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); @@ -137,13 +155,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_TRUE(result.ok()); + TF_ASSERT_OK(result.status()); // Check results. EXPECT_THAT(result.ConsumeValueOrDie(), Eq(std::tuple(0, 10, 1))); } -TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InductionVariableAtTupleElement1) { // Build computation with induction variable at tuple element 1. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(1, 10)); @@ -154,13 +175,16 @@ TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) { RunCopyInsertionPass(); // Run WhileTransformer. auto result = gpu::CanTransformWhileToFor(while_hlo); - ASSERT_TRUE(result.ok()); + TF_ASSERT_OK(result.status()); // Check results. EXPECT_THAT(result.ConsumeValueOrDie(), Eq(std::tuple(0, 10, 1))); } -TEST_F(WhileTransformerTest, InvalidLoopLimit) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InvalidLoopLimit) { // Build computation with invalid loop limit. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 5)); @@ -176,7 +200,10 @@ TEST_F(WhileTransformerTest, InvalidLoopLimit) { HasSubstr("Loop start must be less than loop limit.")); } -TEST_F(WhileTransformerTest, InvalidLoopIncrement) { +// TODO(b/68830972): The while transformer is far too fragile. It patterns +// matches the exact expressions of opcodes. Re-enable when transformation is +// more general +TEST_F(WhileTransformerTest, DISABLED_InvalidLoopIncrement) { // Build computation with invalid loop increment. auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(0, 10)); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 049e8d80d80c835bca4a4d38592564ba82a3ecf9..05017008e2ddbe0b9e78d06275fdec5d08d94bfa 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -108,8 +108,11 @@ std::unique_ptr MakeBigGraph() { HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0)); + HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 17b926c8748e45b55f380e7595711b9e7a748f64..387b649a731ebcbfd8307807469f39f22d192b06 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -259,8 +259,11 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. @@ -292,8 +295,11 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -327,10 +333,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); // The buffer for dot1 is the output. No buffers can be shared. The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -365,10 +374,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 79493c4112804f8454d200f3f83aa85d718f0d0a..5d0cfba1fc8ab255c228c671fee641e9302f5ec6 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -118,6 +118,9 @@ message HloInstructionProto { // Shape of outfeed request. xla.Shape outfeed_shape = 29; + + // Describes the dimension numbers used for a dot operation + xla.DotDimensionNumbers dot_dimension_numbers = 30; } // Serialization of HloComputation. @@ -250,7 +253,3 @@ message HloProto { HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } - -message HloProtos { - repeated HloProto hlo_protos = 1; -} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 6f8099475146e6bbcfb61d2e5a91a7a6f9e63e58..6d2a3aa5b531650a658502531e050702ffbd3760 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -144,8 +144,10 @@ class BufferValueMap { // Move the given value into the given buffer. void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) { BufferNumber old_buffer_number = value_to_buffer_number_.at(&value); - buffers_.at(old_buffer_number).erase(&value); - if (buffers_.at(old_buffer_number).empty()) { + tensorflow::gtl::FlatSet& old_value_set = + buffers_.at(old_buffer_number); + old_value_set.erase(&value); + if (old_value_set.empty()) { buffers_.erase(old_buffer_number); } @@ -175,7 +177,7 @@ class BufferValueMap { // Value is init of a while (use is while). std::vector aliased_buffers; for (const HloUse& use : value.uses()) { - VLOG(1) << "use of value " << value.ToShortString() << ": " << use; + VLOG(2) << "use of value " << value.ToShortString() << ": " << use; if (use.instruction->opcode() == HloOpcode::kWhile) { // Determine the while value that this shares a buffer with. const HloValue& while_value = @@ -411,7 +413,7 @@ string HloAliasAnalysis::ToString() const { /* static */ StatusOr> HloAliasAnalysis::Run( HloModule* module) { - VLOG(1) << "HloAliasAnalysis::Run on module " << module->name(); + VLOG(2) << "HloAliasAnalysis::Run on module " << module->name(); XLA_VLOG_LINES(2, module->ToString()); auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); @@ -444,7 +446,7 @@ StatusOr> HloAliasAnalysis::Run( TF_DCHECK_OK(alias_analysis->Verify()); - XLA_VLOG_LINES(1, alias_analysis->ToString()); + XLA_VLOG_LINES(2, alias_analysis->ToString()); return std::move(alias_analysis); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 72c70b38238eedb67622f4816e1de264f3c9ed4b..014a851c96ed1d530cfd5fa4e854cf1df45fc4d0 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -176,10 +176,6 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { return false; } - if (instruction->HasSideEffect()) { - return false; - } - return true; } @@ -207,7 +203,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.count(item) != 0 || item->user_count() != 0 || - item == root_instruction() || !IsRemovable(item)) { + item == root_instruction() || !IsRemovable(item) || + item->HasSideEffect()) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -367,7 +364,8 @@ std::list HloComputation::MakeEmbeddedComputationsList() return post_order; } -string HloComputation::ToString(int nested_level) const { +string HloComputation::ToString(int nested_level, + bool include_large_constants) const { std::ostringstream s; for (int i = 0; i < nested_level; i++) { s << " "; @@ -379,12 +377,11 @@ string HloComputation::ToString(int nested_level) const { s << " "; } s << " " << (instruction == root_instruction_ ? "ROOT " : "") - << instruction->ToString() << "\n"; - if (instruction->opcode() == HloOpcode::kFusion) { - s << instruction->fused_instructions_computation()->ToString( - nested_level + 1) - << "\n"; - } + << instruction->ToString( + /*compact_operands=*/false, + /*include_metadata=*/true, + /*include_large_constants=*/include_large_constants) + << "\n"; } for (int i = 0; i < nested_level; i++) { s << " "; @@ -407,16 +404,18 @@ HloComputationProto HloComputation::ToProto() const { /* static */ StatusOr> HloComputation::CreateFromProto( HloModule* module, const HloComputationProto& proto, - tensorflow::gtl::FlatMap* computation_map, + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation, HloInstruction* fusion_instruction) { std::vector> instructions; tensorflow::gtl::FlatMap instruction_map; int64 parameter_count = 0; for (const HloInstructionProto& instruction_proto : proto.instructions()) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr instruction, - HloInstruction::CreateFromProto(module, instruction_proto, - instruction_map, computation_map)); + TF_ASSIGN_OR_RETURN(std::unique_ptr instruction, + HloInstruction::CreateFromProto( + module, instruction_proto, instruction_map, + computation_map, add_fused_computation)); if (instruction->opcode() == HloOpcode::kParameter) { parameter_count++; } @@ -654,7 +653,9 @@ std::vector HloComputation::CollectUnreachableRoots() const { return unreachable_roots; } -Status HloComputation::Accept(DfsHloVisitor* visitor) const { +template +Status HloComputation::Accept( + DfsHloVisitorBase* visitor) const { // Visit unreachable roots. Beware that the visitor might delete the currently // visited root, which would invalidate iterators if the unreachable roots // weren't computed ahead of time. @@ -667,6 +668,10 @@ Status HloComputation::Accept(DfsHloVisitor* visitor) const { return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); } +// Explicit instantiations. +template Status HloComputation::Accept(DfsHloVisitor* visitor) const; +template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const; + Status HloComputation::AcceptWithOperandOrder( DfsHloVisitor* visitor, const HloInstruction::CompareFunction& operand_order) const { @@ -683,8 +688,9 @@ Status HloComputation::AcceptWithOperandOrder( /*call_finish_visit=*/true); } +template Status HloComputation::AcceptOrdered( - DfsHloVisitor* visitor, + DfsHloVisitorBase* visitor, const std::vector& order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { @@ -713,49 +719,111 @@ Status HloComputation::AcceptOrdered( return Status::OK(); } +// Explicit instantiations. +template Status HloComputation::AcceptOrdered( + DfsHloVisitor*, const std::vector&) const; +template Status HloComputation::AcceptOrdered( + ConstDfsHloVisitor*, const std::vector&) const; + Status HloComputation::Accept( - const FunctionVisitor::VisitorFunction& visitor_func) const { + const std::function& visitor_func) { FunctionVisitor visitor(visitor_func); return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix) { +Status HloComputation::Accept( + const std::function& visitor_func) const { + ConstFunctionVisitor visitor(visitor_func); + return this->Accept(&visitor); +} + +std::unique_ptr HloComputation::Clone(const string& suffix, + HloModule* module) { + return CloneWithReplacements( + /*replacements=*/std::unordered_map>(), + module, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacements( + std::unordered_map> + replacements, + HloModule* module, const string& suffix) { + // Look up instr in the replacements map, and return either the replacement, + // or instr, if the replacement isn't present. + // + // Note: This can return null, indicating that instr should not be present in + // the new computation. + auto replace = [&](HloInstruction* instr) { + auto it = replacements.find(instr); + if (it == replacements.end()) { + return instr; + } + return it->second.get(); + }; + VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; - auto postorder = MakeInstructionPostOrder(); + std::vector postorder; + for (HloInstruction* instr : MakeInstructionPostOrder()) { + if (HloInstruction* replacement = replace(instr)) { + postorder.push_back(replacement); + } + } + std::unordered_map clone_map; std::vector> instructions; std::unique_ptr new_instr = nullptr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { - HloInstruction* new_operand = FindOrDie(clone_map, operand); - CHECK(new_operand != nullptr); - new_operands.push_back(new_operand); - } - - new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands); - new_instr->set_metadata(instr->metadata()); - if (instr->has_sharding()) { - new_instr->set_sharding(instr->sharding()); + auto replaced_operand = replace(operand); + // If replaced_operand is null, that means 'replacements' asked us not to + // include operand in the new computation. But we can't do that, because + // operand is used by instr. + CHECK_NE(replaced_operand, nullptr) + << "replacements map tried to eliminate a used instruction " + << operand->ToString() << ", used by " << instr->ToString(); + new_operands.push_back(FindOrDie(clone_map, replaced_operand)); } + new_instr = + instr->CloneWithNewOperands(instr->shape(), new_operands, module); InsertOrDie(&clone_map, instr, new_instr.get()); instructions.push_back(std::move(new_instr)); } - Builder builder(name() + suffix); + Builder builder(name() + "." + suffix); for (auto& instr : instructions) { builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, root_instruction())); + /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { HloInstruction* new_instr = FindOrDie(clone_map, instr); for (auto successor : instr->control_successors()) { - TF_CHECK_OK( - new_instr->AddControlDependencyTo(FindOrDie(clone_map, successor))); + auto replaced_successor = replace(successor); + + // successor may not be in clone_map, because it might have been + // removed by the replacements map. + if (replaced_successor == nullptr) { + continue; + } + + TF_CHECK_OK(new_instr->AddControlDependencyTo( + FindOrDie(clone_map, replaced_successor))); + } + } + + // We cloned the elements of 'replacements', so they're all going to be + // destroyed. HloInstructions need to be detached from their operands before + // they're destroyed, otherwise they stick around in the operands' users lists + // and cause use-after-frees. + for (auto& kv : replacements) { + if (std::unique_ptr& new_instr = kv.second) { + new_instr->DetachFromOperands(); } } + return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index f4edd175016ee30d31cc0cad6bdbd3eaa014c704..ccedda2a03c088b93883dd79a101c832497a937a 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -138,7 +138,8 @@ class HloComputation { void UniquifyName(NameUniquer* name_uniquer); // Return a string representation of the computation. - string ToString(int nested_level = 0) const; + string ToString(int nested_level = 0, + bool include_large_constants = false) const; // Returns a serialized representation of this computation. HloComputationProto ToProto() const; @@ -151,12 +152,16 @@ class HloComputation { // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed computation // calls. - // fusion_instruction: if non-null then the newly created computation will be - // constructed as a fused computation with this instruction as its fusion - // parent. + // add_fused_computation: A function to call to add a fused + // computation. Used only when the instruction is a fusion instruction. + // fusion_instruction: if non-null then the newly created computation will + // be constructed as a fused computation with this instruction as its + // fusion parent. static StatusOr> CreateFromProto( HloModule* module, const HloComputationProto& proto, - tensorflow::gtl::FlatMap* computation_map, + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation, HloInstruction* fusion_instruction = nullptr); // Gets the instructions in this computation. @@ -270,7 +275,8 @@ class HloComputation { // via the root. The root instruction of the computation is visited last, and // the visitor's FinishVisit method is called once upon completion (with the // root instruction as the argument). - Status Accept(DfsHloVisitor* visitor) const; + template + Status Accept(DfsHloVisitorBase* visitor) const; // Same as Accept() above, but the order of operand and control predecessor // visitation is determined by the given operand order; if compare(A, B) == @@ -281,20 +287,43 @@ class HloComputation { // Visit every node in the computation in the given order. 'order' must // be a topological sort of all instructions in the computation. - Status AcceptOrdered(DfsHloVisitor* visitor, + template + Status AcceptOrdered(DfsHloVisitorBase* visitor, const std::vector& order) const; // Same as Accept() above, but the visitor is given as a function. - Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const; + Status Accept(const std::function& visitor_func); + Status Accept( + const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - std::unique_ptr Clone(const string& suffix = "clone"); - - // Returns true if the given instruction can be removed from the - // computation. Instructions such as parameters and send/receive instructions - // cannot be removed without violating invariants of the HLO computation or - // module with the exception of fusion computation. A parameter instruction - // is removable for a fusion computation. + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). + std::unique_ptr Clone(const string& suffix = "clone", + HloModule* module = nullptr); + + // Like Clone(), but if an instruction is present in replacement_map, we use + // the map's value to replace that instruction in the cloned computation. + // + // If replacements maps a key to nullptr, we remove that instruction from the + // new computation. + std::unique_ptr CloneWithReplacements( + std::unordered_map> + replacements, + HloModule* module = nullptr, const string& suffix = "clone"); + + // Returns true if the given instruction can be removed from the computation. + // Parameter instructions cannot be removed without violating invariants of + // the HLO computation with the exception of fusion computation. A parameter + // instruction is removable for a fusion computation. + // + // Note that IsRemovable() is a necessariy condition to remove an instruction + // rather than a sufficient condition. For example, instructions with + // side-effect (e.g., Send, Infeed) may be removed from a computation, but the + // transformation must guarantee the invariants relevant to the instructions + // still hold (e.g., Send and Recv must be removed together to make each + // channel complete). bool IsRemovable(const HloInstruction* instruction); // Returns true if this computation has a side effect. A computation has a @@ -307,6 +336,9 @@ class HloComputation { // Returns the owning fusion instruction, or nullptr if this is not a fusion // computation. HloInstruction* FusionInstruction() const { return fusion_instruction_; } + void SetFusionInstruction(HloInstruction* fusion_instruction) { + fusion_instruction_ = fusion_instruction; + } private: explicit HloComputation( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index ab018c4cf2da770eabe74d7b5a670a19937b1b9a..b933695b823871c6c0174da6d6f99e618219442a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -22,13 +22,14 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" namespace xla { constexpr char HloCostAnalysis::kFlopsKey[]; constexpr char HloCostAnalysis::kTranscendentalsKey[]; constexpr char HloCostAnalysis::kBytesAccessedKey[]; -constexpr char HloCostAnalysis::kSecondsKey[]; +constexpr char HloCostAnalysis::kOptimalSecondsKey[]; HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size) : HloCostAnalysis(shape_size, {}) {} @@ -37,7 +38,7 @@ HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size, const Properties& per_second_rates) : shape_size_(shape_size), per_second_rates_(per_second_rates) {} -Status HloCostAnalysis::Preprocess(HloInstruction* hlo) { +Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) { // Set current instruction cost values to reasonable default values. Each // handler can overwrite these values. In Postprocess, these values are // accumulated and written to the per-instruction maps. @@ -56,20 +57,20 @@ Status HloCostAnalysis::Preprocess(HloInstruction* hlo) { return Status::OK(); } -Status HloCostAnalysis::Postprocess(HloInstruction* hlo) { +Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) { if (current_should_compute_bottleneck_time_) { // Compute the time as the time of the bottleneck, i.e. the slowest property // given the per-second rate of each property. - float max_seconds = 0.0f; + float optimal_seconds = 0.0f; for (const auto& property : current_properties_) { - if (property.first != kSecondsKey) { - max_seconds = std::max( - max_seconds, + if (property.first != kOptimalSecondsKey) { + optimal_seconds = std::max( + optimal_seconds, property.second / GetProperty(property.first, per_second_rates_, INFINITY)); } } - current_properties_[kSecondsKey] = max_seconds; + current_properties_[kOptimalSecondsKey] = optimal_seconds; } TF_RET_CHECK(hlo_properties_.emplace(hlo, current_properties_).second); @@ -80,7 +81,8 @@ Status HloCostAnalysis::Postprocess(HloInstruction* hlo) { return Status::OK(); } -Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { +Status HloCostAnalysis::HandleElementwiseOp( + const HloInstruction* hlo_instruction) { const auto& shape = hlo_instruction->shape(); // For element-wise operations, the number of computations is the same as the // number of elements in the output shape. @@ -118,58 +120,64 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { } } -Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo) { +Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } -Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo) { +Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } -Status HloCostAnalysis::HandleCompare(HloInstruction* compare) { +Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) { return HandleElementwiseOp(compare); } -Status HloCostAnalysis::HandleClamp(HloInstruction* clamp) { +Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) { return HandleElementwiseOp(clamp); } -Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo) { +Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } -Status HloCostAnalysis::HandleParameter(HloInstruction*) { +Status HloCostAnalysis::HandleParameter(const HloInstruction*) { current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } -Status HloCostAnalysis::HandleConstant(HloInstruction*) { +Status HloCostAnalysis::HandleConstant(const HloInstruction*) { current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } -Status HloCostAnalysis::HandleGetTupleElement(HloInstruction*) { +Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } -Status HloCostAnalysis::HandleSelect(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleSelect(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleReverse(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleReverse(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleSlice(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleSlice(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleDynamicSlice(HloInstruction*) { +Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleDynamicUpdateSlice(HloInstruction*) { +Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleTuple(HloInstruction* tuple) { +Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) { // The tuple instruction only gathers pointers from inputs (it doesn't iterate // through them). The memory touched is then only the size of the output // index table of the tuple. @@ -178,23 +186,26 @@ Status HloCostAnalysis::HandleTuple(HloInstruction* tuple) { return Status::OK(); } -Status HloCostAnalysis::HandleConcatenate(HloInstruction*) { +Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleConvert(HloInstruction* convert) { +Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) { return HandleElementwiseOp(convert); } -Status HloCostAnalysis::HandleCopy(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleCopy(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleDot(HloInstruction* dot) { +Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); - + int64 reduction_width = + lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); // First divide by reduction width before multiplying by rhs elements to avoid // overflow. int64 fma_count; @@ -210,11 +221,15 @@ Status HloCostAnalysis::HandleDot(HloInstruction* dot) { return Status::OK(); } -Status HloCostAnalysis::HandleInfeed(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleInfeed(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleOutfeed(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleMap(HloInstruction* map) { +Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, ProcessSubcomputation(map->to_apply())); @@ -229,7 +244,7 @@ Status HloCostAnalysis::HandleMap(HloInstruction* map) { return Status::OK(); } -Status HloCostAnalysis::HandleReduce(HloInstruction* reduce) { +Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { auto arg = reduce->operand(0); HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. @@ -247,7 +262,8 @@ Status HloCostAnalysis::HandleReduce(HloInstruction* reduce) { return Status::OK(); } -Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window) { +Status HloCostAnalysis::HandleReduceWindow( + const HloInstruction* reduce_window) { const Window& window = reduce_window->window(); auto function = reduce_window->to_apply(); // Compute the properties of the reduction function. @@ -272,7 +288,8 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window) { return Status::OK(); } -Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { +Status HloCostAnalysis::HandleSelectAndScatter( + const HloInstruction* instruction) { // Compute the properties of the select and scatter function. // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties select_properties, @@ -304,44 +321,60 @@ Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { return Status::OK(); } -Status HloCostAnalysis::HandleBitcast(HloInstruction*) { +Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { // A bitcast does no computation and touches no memory. current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } -Status HloCostAnalysis::HandleBroadcast(HloInstruction*) { +Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) { + return Status::OK(); +} + +Status HloCostAnalysis::HandlePad(const HloInstruction*) { + return Status::OK(); +} + +Status HloCostAnalysis::HandleSend(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandlePad(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleSendDone(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleSend(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleRecv(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleRecv(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleRecvDone(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleReshape(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleReshape(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleBatchNormTraining(HloInstruction*) { +Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) { // TODO(b/62294698): Implement cost analysis for batch-norm-training. return Status::OK(); } -Status HloCostAnalysis::HandleBatchNormInference(HloInstruction*) { +Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) { // TODO(b/62294698): Implement cost analysis for batch-norm-inference. return Status::OK(); } -Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction*) { +Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) { // TODO(b/62294698): Implement cost analysis for batch-norm-grad. return Status::OK(); } -Status HloCostAnalysis::HandleTranspose(HloInstruction*) { +Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution) { +Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { auto rhs_instruction = convolution->operand(1); const auto& dnums = convolution->convolution_dimension_numbers(); const int64 output_features = @@ -359,17 +392,24 @@ Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution) { return Status::OK(); } -Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) { +Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. - current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape()); + double flops = 0.0; + ShapeUtil::ForEachSubshape( + crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; return Status::OK(); } -Status HloCostAnalysis::HandleRng(HloInstruction* random) { +Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume // the cost of each RNG is same as a transcendental operation. @@ -378,7 +418,7 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random) { return Status::OK(); } -Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { +Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { // Compute the properties of the fused expression and attribute them to the // fusion node. Use a dummy shape_size to avoid any errors from trying to // calculate the size of a shape that does not have a layout, since nodes @@ -406,18 +446,18 @@ Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { return Status::OK(); } -Status HloCostAnalysis::HandleCall(HloInstruction* call) { +Status HloCostAnalysis::HandleCall(const HloInstruction* call) { TF_ASSIGN_OR_RETURN(current_properties_, ProcessSubcomputation(call->to_apply())); current_should_compute_bottleneck_time_ = false; return Status::OK(); } -Status HloCostAnalysis::HandleCustomCall(HloInstruction*) { +Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { return Unimplemented("Custom-call is not implemented for HLO cost analysis."); } -Status HloCostAnalysis::HandleSort(HloInstruction* sort) { +Status HloCostAnalysis::HandleSort(const HloInstruction* sort) { // This assumes a comparison based N*log(N) algorithm. As for all ops, the // actual properties of the op depend on the backend implementation. int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape()); @@ -425,7 +465,7 @@ Status HloCostAnalysis::HandleSort(HloInstruction* sort) { return Status::OK(); } -Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { +Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { // Since the number of iterations of the while node will not always be // something that we can statically analyze, we cannot precisely compute the // cost of a while node. For now compute the cost of a single iteration. @@ -449,7 +489,28 @@ Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { return Status::OK(); } -Status HloCostAnalysis::FinishVisit(HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { + // Compute the cost of the true and false computations and take the maximum + // from those for each property. + TF_ASSIGN_OR_RETURN(const Properties true_computation_properties, + ProcessSubcomputation(conditional->true_computation())); + TF_ASSIGN_OR_RETURN(const Properties false_computation_properties, + ProcessSubcomputation(conditional->false_computation())); + current_properties_ = true_computation_properties; + for (const auto& property : false_computation_properties) { + if (!tensorflow::gtl::InsertIfNotPresent(¤t_properties_, property)) { + current_properties_[property.first] = + std::max(current_properties_[property.first], property.second); + } + } + current_should_compute_bottleneck_time_ = false; + + return Status::OK(); +} + +Status HloCostAnalysis::FinishVisit(const HloInstruction*) { + return Status::OK(); +} float HloCostAnalysis::flop_count() const { return GetProperty(kFlopsKey, properties_sum_); @@ -463,8 +524,8 @@ float HloCostAnalysis::bytes_accessed() const { return GetProperty(kBytesAccessedKey, properties_sum_); } -float HloCostAnalysis::seconds() const { - return GetProperty(kSecondsKey, properties_sum_); +float HloCostAnalysis::optimal_seconds() const { + return GetProperty(kOptimalSecondsKey, properties_sum_); } int64 HloCostAnalysis::flop_count(const HloInstruction& hlo) const { @@ -479,8 +540,8 @@ int64 HloCostAnalysis::bytes_accessed(const HloInstruction& hlo) const { return GetPropertyForHlo(hlo, kBytesAccessedKey, hlo_properties_); } -float HloCostAnalysis::seconds(const HloInstruction& hlo) const { - return GetPropertyForHlo(hlo, kSecondsKey, hlo_properties_); +float HloCostAnalysis::optimal_seconds(const HloInstruction& hlo) const { + return GetPropertyForHlo(hlo, kOptimalSecondsKey, hlo_properties_); } StatusOr HloCostAnalysis::ProcessSubcomputation( diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 93b1b3eb20cf88292d38549016c9a0b662e155ee..fade19522cf0c30eab037aa355de1f9203f80014 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -34,7 +34,7 @@ namespace xla { // the computation cost of the instruction, and the values are accumulated // during the traversal for the entire graph. We treat normal floating point // operations separately from transcendental operations. -class HloCostAnalysis : public DfsHloVisitor { +class HloCostAnalysis : public ConstDfsHloVisitor { public: // Each HLO is associated to a vector of properties with the indices given // below. Sub-classes can add further properties. @@ -42,61 +42,66 @@ class HloCostAnalysis : public DfsHloVisitor { static constexpr char kFlopsKey[] = "flops"; static constexpr char kTranscendentalsKey[] = "transcendentals"; static constexpr char kBytesAccessedKey[] = "bytes accessed"; - static constexpr char kSecondsKey[] = "seconds"; + static constexpr char kOptimalSecondsKey[] = "optimal_seconds"; // shape_size is a function which returns the size in bytes of the top-level // buffer of a shape. using ShapeSizeFunction = std::function; explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); - Status HandleElementwiseUnary(HloInstruction* hlo) override; - Status HandleElementwiseBinary(HloInstruction* hlo) override; - Status HandleConstant(HloInstruction* constant) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleSelect(HloInstruction* select) override; - Status HandleCompare(HloInstruction* compare) override; - Status HandleClamp(HloInstruction* clamp) override; - Status HandleReducePrecision(HloInstruction* hlo) override; - Status HandleConcatenate(HloInstruction* concatenate) override; - Status HandleSend(HloInstruction* send) override; - Status HandleRecv(HloInstruction* recv) override; - Status HandleConvert(HloInstruction* convert) override; - Status HandleCopy(HloInstruction* copy) override; - Status HandleDot(HloInstruction* dot) override; - Status HandleConvolution(HloInstruction* convolution) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleInfeed(HloInstruction* infeed) override; - Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleRng(HloInstruction* random) override; - Status HandleReverse(HloInstruction* reverse) override; - Status HandleSort(HloInstruction* sort) override; - Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce) override; - Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; + Status HandleElementwiseUnary(const HloInstruction* hlo) override; + Status HandleElementwiseBinary(const HloInstruction* hlo) override; + Status HandleConstant(const HloInstruction* constant) override; + Status HandleGetTupleElement( + const HloInstruction* get_tuple_element) override; + Status HandleSelect(const HloInstruction* select) override; + Status HandleCompare(const HloInstruction* compare) override; + Status HandleClamp(const HloInstruction* clamp) override; + Status HandleReducePrecision(const HloInstruction* hlo) override; + Status HandleConcatenate(const HloInstruction* concatenate) override; + Status HandleSend(const HloInstruction* send) override; + Status HandleSendDone(const HloInstruction* send_done) override; + Status HandleRecv(const HloInstruction* recv) override; + Status HandleRecvDone(const HloInstruction* recv_done) override; + Status HandleConvert(const HloInstruction* convert) override; + Status HandleCopy(const HloInstruction* copy) override; + Status HandleDot(const HloInstruction* dot) override; + Status HandleConvolution(const HloInstruction* convolution) override; + Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleInfeed(const HloInstruction* infeed) override; + Status HandleOutfeed(const HloInstruction* outfeed) override; + Status HandleRng(const HloInstruction* random) override; + Status HandleReverse(const HloInstruction* reverse) override; + Status HandleSort(const HloInstruction* sort) override; + Status HandleParameter(const HloInstruction* parameter) override; + Status HandleReduce(const HloInstruction* reduce) override; + Status HandleBatchNormTraining( + const HloInstruction* batch_norm_training) override; Status HandleBatchNormInference( - HloInstruction* batch_norm_inference) override; - Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call) override; - Status HandleCustomCall(HloInstruction* custom_call) override; - Status HandleSlice(HloInstruction* slice) override; - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + const HloInstruction* batch_norm_inference) override; + Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override; + Status HandleFusion(const HloInstruction* fusion) override; + Status HandleCall(const HloInstruction* call) override; + Status HandleCustomCall(const HloInstruction* custom_call) override; + Status HandleSlice(const HloInstruction* slice) override; + Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override; - Status HandleTuple(HloInstruction* tuple) override; - Status HandleMap(HloInstruction* map) override; - Status HandleReduceWindow(HloInstruction* reduce_window) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; - Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandlePad(HloInstruction* pad) override; - Status HandleReshape(HloInstruction* reshape) override; - Status HandleTranspose(HloInstruction* transpose) override; - Status HandleWhile(HloInstruction* xla_while) override; - Status FinishVisit(HloInstruction* root) override; - - Status Preprocess(HloInstruction* hlo) override; - Status Postprocess(HloInstruction* hlo) override; + const HloInstruction* dynamic_update_slice) override; + Status HandleTuple(const HloInstruction* tuple) override; + Status HandleMap(const HloInstruction* map) override; + Status HandleReduceWindow(const HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(const HloInstruction* instruction) override; + Status HandleBitcast(const HloInstruction* bitcast) override; + Status HandleBroadcast(const HloInstruction* broadcast) override; + Status HandlePad(const HloInstruction* pad) override; + Status HandleReshape(const HloInstruction* reshape) override; + Status HandleTranspose(const HloInstruction* transpose) override; + Status HandleWhile(const HloInstruction* xla_while) override; + Status HandleConditional(const HloInstruction* conditional) override; + Status FinishVisit(const HloInstruction* root) override; + + Status Preprocess(const HloInstruction* hlo) override; + Status Postprocess(const HloInstruction* hlo) override; // Set the rates used to calculate the time taken by the computation. These // need to be set before visiting starts. @@ -114,14 +119,14 @@ class HloCostAnalysis : public DfsHloVisitor { float flop_count() const; float transcendental_count() const; float bytes_accessed() const; - float seconds() const; + float optimal_seconds() const; // Returns the respective cost computed for a particular HLO instruction, or 0 // if the HLO was not found to have a cost in the analysis. int64 flop_count(const HloInstruction& hlo) const; int64 transcendental_count(const HloInstruction& hlo) const; int64 bytes_accessed(const HloInstruction& hlo) const; - float seconds(const HloInstruction& hlo) const; + float optimal_seconds(const HloInstruction& hlo) const; const Properties& properties() const { return properties_sum_; } const float property(const string& key) const { @@ -145,7 +150,7 @@ class HloCostAnalysis : public DfsHloVisitor { const ShapeSizeFunction* shape_size = nullptr); // Utility function to handle all element-wise operations. - Status HandleElementwiseOp(HloInstruction* hlo_instruction); + Status HandleElementwiseOp(const HloInstruction* hlo_instruction); // Returns the default value if the key is not present in the // properties. Otherwise, returns the value that the key maps to from the diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 0eaa21ef254e3461baaaca57503ab24ce35ac929..3b289c240a45e8f3df8156ed89e879da2132d01a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -389,7 +389,7 @@ TEST_F(FusionCostAnalysis, LoopFusion) { static_assert(bytes_accessed == 64, ""); EXPECT_EQ(fusion_analysis.bytes_accessed(), bytes_accessed); - EXPECT_EQ(fusion_analysis.seconds(), 1 << i); + EXPECT_EQ(fusion_analysis.optimal_seconds(), 1 << i); } } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 7c4626e78a3e84c9723a9f8e39d56614c4fa25ce..3601a790c4428ee39c264b217a4b9a991ad8456c 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -79,12 +79,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { // Test that two identical constants with different layouts are commoned if // the pass is not layout sensitive. auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{0, 1}))); - auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{1, 0}))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -111,12 +111,12 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { // Test that two identical constants with different layouts are *not* commoned // if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{0, 1}))); - auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - /*minor_to_major=*/{1, 0}))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 92261bce6270e3c37165c10ed804d036d2abb984..2a335843f507e2071807245d4dd256e1ec6f08c8 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -75,11 +75,43 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction, std::forward_as_tuple(value_id, instruction, index, is_phi)); CHECK(emplaced.second); + VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString(); + return &emplaced.first->second; } -void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { - values_.erase(value_id); +void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) { + HloValue& value = values_.at(value_id); + VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")"; + + value_ids_to_delete_.push_back(value_id); +} + +void HloDataflowAnalysis::DeleteMarkedValues() { +#ifndef NDEBUG + // Verify that no marked-for-deletion values are in any of the value sets. + tensorflow::gtl::FlatSet id_set(value_ids_to_delete_.begin(), + value_ids_to_delete_.end()); + for (const auto& pair : value_sets_) { + const HloInstruction* instruction = pair.first; + const InstructionValueSet& instruction_value_set = pair.second; + for (const auto& index_value_set : instruction_value_set) { + const HloValueSet& value_set = index_value_set.second; + for (const HloValue* value : value_set.values()) { + DCHECK(!ContainsKey(id_set, value->id())) + << "Value " << value->ToShortString() + << " marked for deletion, but still exists in value set for " + "instruction " + << instruction->name(); + } + } + } +#endif + + for (HloValue::Id value_id : value_ids_to_delete_) { + values_.erase(value_id); + } + value_ids_to_delete_.clear(); } string HloDataflowAnalysis::ToString() const { @@ -121,6 +153,7 @@ bool HloDataflowAnalysis::Phi( HloInstruction* instruction, tensorflow::gtl::ArraySlice inputs) { CHECK(ssa_form_); + VLOG(4) << "Phi(" << instruction->name() << ")"; for (const InstructionValueSet* input : inputs) { DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); @@ -183,7 +216,7 @@ bool HloDataflowAnalysis::Phi( } else if (current_value != &new_value) { if (current_value_defined_here) { // Remove the existing phi. - DeleteHloValue(current_value->id()); + MarkValueForDeletion(current_value->id()); } value_set.Clear(); value_set.AddValue(&new_value); @@ -193,7 +226,8 @@ bool HloDataflowAnalysis::Phi( // Multiple distinct values reach this point. A phi value is // necessary. CHECK_GT(input_value_ids.size(), 1); - if (current_value == nullptr || !current_value->is_phi()) { + if (current_value == nullptr || + !(current_value->is_phi() && current_value_defined_here)) { value_set.Clear(); value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true)); changed = true; @@ -242,6 +276,51 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } +bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { + CHECK_EQ(send->opcode(), HloOpcode::kSend); + bool changed = false; + // Send forwards the operand value to the output tuple at {0}. + for (auto& pair : GetInstructionValueSet(send->operand(0))) { + const ShapeIndex& operand_index = pair.first; + const HloValueSet& operand_value_set = pair.second; + + ShapeIndex index = {0}; + for (int64 i : operand_index) { + index.push_back(i); + } + + HloValueSet& value_set = GetValueSet(send, index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + +bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) { + CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone); + bool changed = false; + // RecvDone forwards the operand value at {0} to the output. + for (auto& pair : GetInstructionValueSet(recv_done)) { + ShapeIndex& index = pair.first; + HloValueSet& value_set = pair.second; + + ShapeIndex operand_index = {0}; + for (int64 i : index) { + operand_index.push_back(i); + } + + const HloValueSet& operand_value_set = + GetValueSet(recv_done->operand(0), operand_index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { CHECK_EQ(call->opcode(), HloOpcode::kCall); InstructionValueSet& value_set = GetInstructionValueSet(call); @@ -254,6 +333,21 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { return false; } +bool HloDataflowAnalysis::UpdateConditionalValueSet( + HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + std::vector inputs = { + &GetInstructionValueSet( + conditional->true_computation()->root_instruction()), + &GetInstructionValueSet( + conditional->false_computation()->root_instruction())}; + // A phi-node is not defined for a kConditional instruction even though it + // represents a join point. This is because the current approach is to define + // a phi-node only for kWhile to account for the dataflow through back-edges + // and deal with the ambiguity in other cases. + return GetInstructionValueSet(conditional).AssignUnionOf(inputs); +} + bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { CHECK_EQ(copy->opcode(), HloOpcode::kCopy); bool changed = false; @@ -315,7 +409,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { CHECK_EQ(call_graph_node.context(), CallContext::kSequential); std::vector inputs; - bool called_from_while = false; + bool need_phi = false; for (const CallSite& callsite : call_graph_node.caller_callsites()) { if (callsite.instruction()->opcode() == HloOpcode::kCall) { // The operand values of a call instruction are forwarded to the @@ -337,14 +431,32 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { inputs.push_back(&GetInstructionValueSet( callsite.instruction()->while_body()->root_instruction())); } - called_from_while = true; + need_phi = true; + } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + CHECK_EQ(parameter->parameter_number(), 0); + auto conditional = callsite.instruction(); + // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is + // the argument to the true computation and operand 2 is the argument to + // the false computation. + // + // If the parameter belongs to conditional's true computation, then + // operand 1 is forwarded to this parameter instruction. If the parameter + // belongs to conditional's false computation, then operand 2 is forwarded + // to this parameter instruction. + if (parameter->parent() == conditional->true_computation()) { + inputs.push_back(&GetInstructionValueSet(conditional->operand(1))); + } else { + CHECK_EQ(parameter->parent(), conditional->false_computation()); + inputs.push_back(&GetInstructionValueSet(conditional->operand(2))); + } + need_phi = true; } else { LOG(FATAL) << "CallContext::kSequential computations should only be " - "called from call or while instructions"; + "called from call, while, or conditional instructions"; } } - if (ssa_form_ && called_from_while) { + if (ssa_form_ && need_phi) { return Phi(parameter, inputs); } else { return GetInstructionValueSet(parameter).AssignUnionOf(inputs); @@ -429,6 +541,12 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateCallValueSet(instruction); case HloOpcode::kWhile: return UpdateWhileValueSet(instruction); + case HloOpcode::kSend: + return UpdateSendValueSet(instruction); + case HloOpcode::kRecvDone: + return UpdateRecvDoneValueSet(instruction); + case HloOpcode::kConditional: + return UpdateConditionalValueSet(instruction); default: // Instruction does not forward HloValues (it defines all values in its // output). No update is necessary. @@ -436,11 +554,13 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( } } -void HloDataflowAnalysis::UpdateInstructionsAndPropagate( - tensorflow::gtl::ArraySlice instructions) { +void HloDataflowAnalysis::Propagate() { std::queue worklist; - for (HloInstruction* instruction : instructions) { - worklist.push(instruction); + + for (HloComputation* computation : module_->computations()) { + for (HloInstruction* instruction : computation->instructions()) { + worklist.push(instruction); + } } while (!worklist.empty()) { @@ -465,13 +585,31 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate( // If user sequentially calls a computation, then the respective // parameter(s) of the computation need to be updated. - for (HloComputation* called_computation : user->called_computations()) { - const CallGraphNode& call_graph_node = - call_graph_->GetNode(called_computation); - if (call_graph_node.context() == CallContext::kSequential) { - for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( - called_computation->parameter_instruction(operand_number)); + if (user->opcode() == HloOpcode::kConditional) { + // If operand 0 is the use of instruction, then no parameters need to be + // updated, since that is the predicate of the conditional. + // If operand 1 is the use of instruction, then the true_computation's + // parameter need to be updated. + // If operand 2 is the use of instruction, then the false_computation's + // parameter need to be updated. + // + // Note that the same instruction can be used in both operand 1 and + // operand 2. + if (user->operand(1) == instruction) { + worklist.push(user->true_computation()->parameter_instruction(0)); + } + if (user->operand(2) == instruction) { + worklist.push(user->false_computation()->parameter_instruction(0)); + } + } else { + for (HloComputation* called_computation : user->called_computations()) { + const CallGraphNode& call_graph_node = + call_graph_->GetNode(called_computation); + if (call_graph_node.context() == CallContext::kSequential) { + for (int64 operand_number : user->OperandIndices(instruction)) { + worklist.push( + called_computation->parameter_instruction(operand_number)); + } } } } @@ -483,7 +621,8 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate( const CallGraphNode& call_graph_node = call_graph_->GetNode(instruction->parent()); for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if (callsite.instruction()->opcode() == HloOpcode::kCall) { + if ((callsite.instruction()->opcode() == HloOpcode::kCall) || + (callsite.instruction()->opcode() == HloOpcode::kConditional)) { worklist.push(callsite.instruction()); } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. @@ -537,6 +676,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { GetValueSet(instruction, /*index=*/{}).AddValue(value); }; + // Lambda to set the value set at the given index of the output. + auto define_value_at = [this, &instruction](const ShapeIndex& index) { + HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); + GetValueSet(instruction, index).AddValue(value); + }; + switch (instruction->opcode()) { case HloOpcode::kBitcast: if (bitcast_defines_value_) { @@ -545,6 +690,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { break; case HloOpcode::kWhile: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. @@ -577,6 +723,16 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // values flow from their operands. define_top_level_only(); break; + case HloOpcode::kRecvDone: + // RecvDone aliases its input tuple element {0}, therefore does not + // define any values. + break; + case HloOpcode::kSend: + // Send produces a tuple of {aliased operand, U32 context}, therefore + // only defines the top-level tuple and the tuple element at {1}. + define_value_at(/*index=*/{}); + define_value_at(/*index=*/{1}); + break; default: define_all_values(); break; @@ -597,20 +753,17 @@ StatusOr> HloDataflowAnalysis::Run( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets()); + dataflow_analysis->Propagate(); - // Construct list of all instructions to initialize the worklist to propagate - // the data flow. For efficiency sort the instruction in post order so - // producers appear before consumers. - std::vector all_instructions; - for (const HloComputation* computation : module->MakeComputationPostOrder()) { - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - all_instructions.push_back(instruction); - } - } - dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions); + // Delete all values marked for deletion. + dataflow_analysis->DeleteMarkedValues(); - // Add in positions to all values. + // Gather and set all non-definition positions of all values. Value deletion + // is rare, so just use a vector indexed by Value::Id rather than a map from + // Value::Id to positions. There should be very few holes in the vector, and + // lookup is faster. + std::vector> value_positions( + dataflow_analysis->next_value_id_); for (const HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : @@ -619,13 +772,18 @@ StatusOr> HloDataflowAnalysis::Run( const HloValueSet& value_set = pair.second; for (const HloValue* value : value_set.values()) { if (value->defining_instruction() != instruction) { - dataflow_analysis->GetValue(value->id()) - .AddPosition(instruction, index); + value_positions[value->id()].push_back( + HloPosition{instruction, index}); } } } } } + for (auto& pair : dataflow_analysis->values_) { + HloValue::Id value_id = pair.first; + HloValue& value = pair.second; + value.SetPositionsAndComputeUses(value_positions[value_id]); + } // Construct vector of values. dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size()); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 207e553bf7fb62e19b9fa89eaf6bfb3234592c11..469620d01295f90e0c36a48cac9be47c12473a68 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -126,13 +126,16 @@ class HloDataflowAnalysis { HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); - // Delete the HloValue with the given ID. - void DeleteHloValue(HloValue::Id value_id); + // Mark the HloValue with the given ID for deletion. + void MarkValueForDeletion(HloValue::Id value_id); + + // Delete all HloValues marked for deletion. Should be called after + // propagation is complete. + void DeleteMarkedValues(); // Constructs and initializes the InstructionValueSets of all instructions to // contain exactly the HloValues defined by each instruction. These values can - // then propagated throughout the HLO graph by calling - // UpdateInstructionsAndPropagate. + // then propagated throughout the HLO graph by calling Propagate. Status InitializeInstructionValueSets(); // Updates the value set of the given instruction based on the values flowing @@ -143,17 +146,18 @@ class HloDataflowAnalysis { // the instruction value set changed. bool UpdateBitcastValueSet(HloInstruction* bitcast); bool UpdateCallValueSet(HloInstruction* call); + bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); + bool UpdateRecvDoneValueSet(HloInstruction* recv_done); bool UpdateSelectValueSet(HloInstruction* select); + bool UpdateSendValueSet(HloInstruction* send); bool UpdateTupleValueSet(HloInstruction* tuple); bool UpdateWhileValueSet(HloInstruction* xla_while); - // Update the value sets of the given instructions and propagate the - // changes to fixed point. - void UpdateInstructionsAndPropagate( - tensorflow::gtl::ArraySlice instructions); + // Propagate the dataflow through the module. + void Propagate(); // Return the result of the SSA Phi function applied to the given inputs at // the given instruction. If skip_top_level is true, then the top level of the @@ -189,6 +193,11 @@ class HloDataflowAnalysis { // A map from instruction to InstructionValueSet. std::unordered_map value_sets_; + // Values marked for deletion during construction. We don't delete them + // immediately because references to them may remain in ValueSets temporarily + // during propagation. After construction, these values are deleted. + std::vector value_ids_to_delete_; + // A vector containing all HloValues sorted by HloValue::Id. std::vector values_vector_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4b8eb237a6712804657bb7b67cdde9a2d331bd11..e714b2567fd1b3eab607a19f0bb7e3288150dc64 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace { +using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; // Test is parameterized on a bool which is whether the dataflow analysis is @@ -77,11 +78,23 @@ class HloDataflowAnalysisTest : public HloTestBase, analysis_->GetValueDefinedAt(b), *analysis_); } + std::unique_ptr CreateR0F32UnaryOpComputation( + HloOpcode opcode) { + HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode)); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, opcode, param0)); + return builder.Build(); + } + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { @@ -211,10 +224,10 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloPosition{nested_tuple, {0, 0}}, HloPosition{nested_tuple, {1, 0}}, HloPosition{nested_tuple, {2}}, HloPosition{gte_tuple, {0}}, HloPosition{gte_out, {}})); - // Constant values should have no uses though one is live out. The positions - // where they appear as operands are on instructions which do not use the - // values (eg, Tuple). - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + // Constant values should have only a single use, which is the root of the + // computation. + EXPECT_THAT(analysis.GetValueDefinedAt(constant1, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{gte_out, 0, {0}})); EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); // The top-level tuple values are used in GTE instructions. @@ -274,12 +287,11 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{add, 1, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { @@ -323,18 +335,17 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{call1, 0, {}}, HloUse{call2, 0, {}}, + HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{call1, 1, {}}, HloUse{call2, 1, {}}, + HloUse{add, 1, {}})); // The Add from the subcomputation is used as both operands of the Subtract. EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { @@ -408,7 +419,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto outer_param1 = outer_builder.AddInstruction( HloInstruction::CreateParameter(1, scalar_shape_, "param1")); // Swizzle parameters. - outer_builder.AddInstruction(HloInstruction::CreateCall( + auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {outer_param1, outer_param0}, inner_computation)); HloComputation* outer_computation = module_->AddEmbeddedComputation(outer_builder.Build()); @@ -418,7 +429,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(2.0))); - builder.AddInstruction(HloInstruction::CreateCall( + auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); module_->AddEntryComputation(builder.Build()); @@ -431,10 +442,14 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { // Verify that the uses of the constants are properly swizzled by parameter // permutation in nested_call. - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 1, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 0, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, + HloUse{add, 1, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, + HloUse{add, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } @@ -469,7 +484,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); - body_builder.AddInstruction( + auto body_root = body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); @@ -496,8 +511,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); - EXPECT_TRUE( - analysis.GetValueDefinedAt(cond_constant).live_out_of_computation()); EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); if (ssa_form) { @@ -517,14 +530,14 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_THAT( analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}})); + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}}, + HloUse{xla_while, 0, {0}})); // Constant1 passes through the body and out of the module. EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) .live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); } else { // While instruction and subcomputation parameters should not define values @@ -538,7 +551,6 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } } @@ -915,9 +927,11 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { HloUse{select12, 1, {}})); // The two constant values just pass through the Selects and are not - // used. They are live out however. - EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); - EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + // used except at the root. They are live out however. + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{select1234, 1, {0}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{select1234, 1, {0}})); EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); } @@ -1139,6 +1153,54 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); } +TEST_P(HloDataflowAnalysisTest, SendAndSendDone) { + // Test that a Send forwards its operand to the output tuple at {0}. + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(param, /*channel_id=*/0)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + module_->AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 4); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(param)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(send, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send, /*index=*/{1})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done)); + EXPECT_THAT(HloValuesAt(send, /*index=*/{0}), + UnorderedElementsAre(analysis.GetValueDefinedAt(param))); +} + +TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) { + // Test that a RecvDone forwards its operand tuple element at {0} to the + // output. + auto builder = HloComputation::Builder(TestName()); + auto recv = builder.AddInstruction( + HloInstruction::CreateRecv(scalar_shape_, /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + module_->AddEntryComputation(builder.Build()); + + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + EXPECT_EQ(analysis.values().size(), 3); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{0})); + EXPECT_TRUE(analysis.ValueIsDefinedAt(recv, /*index=*/{1})); + EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done)); + EXPECT_THAT(HloValuesAt(recv_done), + UnorderedElementsAre(analysis.GetValueDefinedAt(recv, {0}))); + EXPECT_TRUE( + analysis.GetValueDefinedAt(recv, /*index=*/{0}).live_out_of_module()); +} + TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { // A simple chain of elementwise operations. No values should interfere. // @@ -1270,7 +1332,7 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { auto entry = module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); - const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + RunAnalysis(ssa_form); SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param, xla_while}}); @@ -1281,12 +1343,6 @@ TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { SequentialHloOrdering ordering(module_.get(), sequence); - // 'add' is the body root even though later instructions follow in the order - // like 'dead_negate'. Only 'add' should be live out of the computation. - EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); - EXPECT_FALSE( - analysis.GetValueDefinedAt(dead_negate).live_out_of_computation()); - // 'add' is live out of the body and will interfere with an later instructions // such as 'dead_constant' and 'dead_negate'. EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant)); @@ -1485,6 +1541,315 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); } +TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { + // Test conditional with identity computations in both true and false cases. + // + // true_computation(F32[] %true_param): + // return %true_param + // + // false_computation(F32[] %false_param): + // return %false_param + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // return Conditional(%pred, %constant1, true_computation, + // %constant2, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "true_param")); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "false_param")); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, constant1, true_computation, constant2, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + ElementsAre(HloUse{conditional, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + ElementsAre(HloUse{conditional, 2, {}})); + + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); +} + +TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { + // Test conditional with true and false computations taking a tuple operand. + // + // true_computation((F32[], F32[]) %true_param): + // %true_x = GetTupleElement(%true_param, 0) + // %true_y = GetTupleElement(%true_param, 1) + // return Add(%true_x, %true_y) + // + // false_computation((F32[], F32[]) %false_param): + // %false_x = GetTupleElement(%false_param, 0) + // %false_y = GetTupleElement(%false_param, 1) + // return Subtract(%false_x, %false_y) + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // %tuple_operand = Tuple(%constant1, %constant2) + // return Conditional(%pred, %tuple_operand, true_computation, + // %tuple_operand, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "true_param")); + auto true_x = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0)); + auto true_y = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1)); + auto add = true_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, true_x, true_y)); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "false_param")); + auto false_x = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0)); + auto false_y = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1)); + auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kSubtract, false_x, false_y)); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, tuple_operand, true_computation, tuple_operand, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_y), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_y), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {0}}, + HloUse{conditional, 2, {0}}, + HloUse{add, 0, {}}, HloUse{sub, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {1}}, + HloUse{conditional, 2, {1}}, + HloUse{add, 1, {}}, HloUse{sub, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(), + UnorderedElementsAre( + HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}}, + HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, + HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); + + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); +} + +TEST_P(HloDataflowAnalysisTest, NestedConditionals) { + // computation1(F32[] %param1): + // %ceil = Ceil(%param1) + // return %ceil + // + // computation2(F32[] %param2): + // %floor = Floor(%param2) + // return %floor + // + // computation3(F32[] %param3): + // %negate = Negate(%param3) + // return %negate + // + // inner_conditional((PRED, F32[], F32[]) %param_cond): + // %pred_cond = GetTupleElement(%param_cond, 0) + // %true_operand_cond = GetTupleElement(%param_cond, 1) + // %false_opearnd_cond = GetTupleElement(%param_cond, 2) + // return Conditional(%pred_cond, %true_operand_cond, computation1, + // %false_operand_cond, computation2) + // + // entry: + // %pred1 = Constant(true) + // %pred2 = Constant(false) + // %constant1 = Constant(1.1); + // %constant2 = Constant(2.2); + // %constant3 = Constant(3.3); + // return Conditional(%pred1, (%pred2, %constant1, %constant2), + // inner_conditional, %constant3, computation3) + + auto computation1 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kCeil)); + auto computation2 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kFloor)); + auto computation3 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kNegate)); + + // Build inner_conditional computation. + const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {}); + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {scalar_bool_shape, scalar_shape_, scalar_shape_}); + auto inner_builder = + HloComputation::Builder(TestName() + "_inner_conditional"); + auto param_cond = inner_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond")); + auto pred_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0)); + auto true_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1)); + auto false_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2)); + auto inner_conditional = + inner_builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred_cond, true_operand_cond, computation1, + false_operand_cond, computation2)); + auto inner_conditional_computation = + module_->AddEmbeddedComputation(inner_builder.Build()); + + // Build entry computation. + auto builder = HloComputation::Builder(TestName()); + auto pred1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto pred2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.2f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(3.3f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({pred2, constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred1, tuple_operand, inner_conditional_computation, + constant3, computation3)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction())); + + auto computation1_param = computation1->parameter_instruction(0); + auto computation2_param = computation2->parameter_instruction(0); + auto computation3_param = computation3->parameter_instruction(0); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param), + analysis.GetValueDefinedAt(constant3)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond)); + EXPECT_EQ(analysis.GetUniqueValueAt(param_cond), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond), + analysis.GetValueDefinedAt(pred2)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index a4921232f5848dbe1789c4c641e2b0ba3c1848bb..1e5f0f797a13fd7e7ce1cc934387a274a74153bc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -37,6 +37,9 @@ namespace xla { StatusOr HloDCE::Run(HloModule* module) { bool changed = false; + VLOG(2) << "Before dce:"; + XLA_VLOG_LINES(2, module->ToString()); + for (auto* computation : module->MakeNonfusionComputations()) { std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( @@ -52,12 +55,15 @@ StatusOr HloDCE::Run(HloModule* module) { for (auto* instruction : computation->instructions()) { if (instruction->user_count() == 0 && live_instructions.count(instruction) == 0 && - computation->IsRemovable(instruction)) { + computation->IsRemovable(instruction) && + !instruction->HasSideEffect()) { dead_roots.push_back(instruction); } } for (HloInstruction* dead_root : dead_roots) { + VLOG(1) << "Removing dead root " << dead_root->ToString() + << " and it's unused operands"; TF_RETURN_IF_ERROR( computation->RemoveInstructionAndUnusedOperands(dead_root)); changed = true; @@ -87,6 +93,9 @@ StatusOr HloDCE::Run(HloModule* module) { } } + VLOG(2) << "After dce:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index d54b9a27087a42fd23eab0bd06e8deaca567312b..5a56607a665c4cbeb7b2572f182b88e890602968 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -70,6 +70,26 @@ TEST_F(HloDceTest, NoDeadCode) { EXPECT_EQ(3, computation->instruction_count()); } +TEST_F(HloDceTest, InstructionsWithSideEffect) { + // Verify that side-effect instructions (Send in this test) are not removed. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + + HloDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); +} + TEST_F(HloDceTest, DeadParameters) { // Verify that dead parameters are not removed, but use of the dead parameters // are. diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..1773bb401d380031f6c860d295e76d2f62c9e5ff --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -0,0 +1,137 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace { + +HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) { + if (hlo->shape().element_type() != type) { + Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); + hlo = hlo->parent()->AddInstruction( + HloInstruction::CreateConvert(shape, hlo)); + } + CHECK_EQ(hlo->shape().element_type(), type); + return hlo; +} + +bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == type) { + return true; + } + } + return false; +} + +} // namespace + +HloElementTypeConverter::HloElementTypeConverter( + PrimitiveType eliminate_type, PrimitiveType replace_with_type) + : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {} + +StatusOr HloElementTypeConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* hlo : computation->MakeInstructionPostOrder()) { + // These are ops where it does not make sense to convert them. + if (hlo->opcode() == HloOpcode::kParameter || + hlo->opcode() == HloOpcode::kConstant || + hlo->opcode() == HloOpcode::kTuple || + hlo->opcode() == HloOpcode::kConvert || + hlo->opcode() == HloOpcode::kGetTupleElement || + hlo->opcode() == HloOpcode::kInfeed || + hlo->opcode() == HloOpcode::kOutfeed) { + continue; + } + + // We cannot change a CustomCall since we have no way of adjusting the + // called binary to expect the updated type. + if (hlo->opcode() == HloOpcode::kCustomCall) { + continue; + } + + // These are ops with embedded computations where it suffices to convert + // the embedded computations instead of converting the ops themselves. + if (hlo->opcode() == HloOpcode::kWhile || + hlo->opcode() == HloOpcode::kCall || + hlo->opcode() == HloOpcode::kFusion || + hlo->opcode() == HloOpcode::kMap || + hlo->opcode() == HloOpcode::kReduce || + hlo->opcode() == HloOpcode::kReduceWindow || + hlo->opcode() == HloOpcode::kSelectAndScatter || + hlo->opcode() == HloOpcode::kConditional) { + continue; + } + TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); + + if (!HasOperandType(hlo, eliminate_type_)) { + // If this CHECK fires, then this was an instruction that does not take + // the elimination type as an operand but it does return it. This pass + // does not have a feature to change the output type in that case, so + // instead of silently failing to eliminate the type, it fails loudly. + TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_); + continue; + } + + std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == eliminate_type_) { + operand = ToElementType(operand, replace_with_type_); + } + new_operands.push_back(operand); + } + + HloInstruction* new_hlo; + if (hlo->shape().element_type() == eliminate_type_) { + Shape shape = + ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + new_hlo = ToElementType(new_hlo, eliminate_type_); + } else { + new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( + hlo->shape(), new_operands, hlo->GetModule())); + } + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); + changed = true; + } + } + XLA_VLOG_LINES( + 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..2b109225d0b192e5c9e4f6d841377ffad8078dc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass that eliminates certain element types as the input or output of ops by +// inserting Convert ops. This allows a backend to support an element type while +// only actually implementing the Convert op for that element type. This is +// generally not the fastest approach, but it works. +class HloElementTypeConverter : public HloPassInterface { + public: + // eliminate_type is the type to eliminate as the input or output of ops, + // using Convert ops to replace it with replace_with_type. + HloElementTypeConverter(PrimitiveType eliminate_type, + PrimitiveType replace_with_type); + + tensorflow::StringPiece name() const override { + return "element_type_converter"; + } + + // Returns the pass on the module and returns whether the module was modified. + StatusOr Run(HloModule* module) override; + + private: + PrimitiveType eliminate_type_; + PrimitiveType replace_with_type_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 88b77ccdd03eb129f81cfa1da430e882ea569df4..e693d167a1f96f65b894d07fb2c8f33e61ff8c49 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -335,9 +335,31 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + return ~elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + return !elem_operand; + })); + return Status::OK(); + } + + template ::value>::type* = + nullptr> Status HandleNot(HloInstruction* not_) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { @@ -357,7 +379,24 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleNot(not_); } - Status HandleNegate(HloInstruction* negate) override { + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { + return NativeT(-type(elem_operand)); + })); + return Status::OK(); + } + + template ::value || + std::is_floating_point::value>::type* = nullptr> + Status HandleNegate(HloInstruction* negate) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { return -elem_operand; @@ -365,6 +404,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleNegate(HloInstruction* negate) override { + return HandleNegate(negate); + } + template < typename NativeT, typename std::enable_if::value>::type* = nullptr> @@ -402,7 +445,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleMultiply(HloInstruction* multiply) override { + template ::value && + !std::is_floating_point::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { + using type = typename std::make_unsigned::type; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[multiply], + ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + std::is_floating_point::value || + is_complex_t::value>::type* = nullptr> + Status HandleMultiply(HloInstruction* multiply) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { @@ -411,6 +473,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleMultiply(HloInstruction* multiply) override { + return HandleMultiply(multiply); + } + Status HandleSubtract(HloInstruction* subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], @@ -516,9 +582,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleRemainder(remainder); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> + Status HandleAnd(HloInstruction* and_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el & rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> Status HandleAnd(HloInstruction* and_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[and_], @@ -539,9 +616,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleAnd(and_); } - template < - typename NativeT, - typename std::enable_if::value>::type* = nullptr> + template ::value>::type* = + nullptr> + Status HandleOr(HloInstruction* or_) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { + return lhs_el | rhs_el; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> Status HandleOr(HloInstruction* or_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[or_], @@ -645,7 +733,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = - [](ReturnT low, ReturnT high, ReturnT value) { + [](ReturnT low, ReturnT value, ReturnT high) { return std::fmax(low, std::fmin(value, high)); }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], @@ -724,7 +812,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { CHECK(ShapeUtil::SameElementType(lhs_shape, result_shape)); const auto& dnums = conv->convolution_dimension_numbers(); - const int64 num_spatial_dims = dnums.spatial_dimensions_size(); + const int64 num_spatial_dims = dnums.output_spatial_dimensions_size(); + CHECK_EQ(num_spatial_dims, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); @@ -789,13 +878,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { // Find corresponding spatial dimension index for input (lhs). for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) { // Spatial dimension number for input (lhs) and output. - const int64 spatial_dim = dnums.spatial_dimensions(ki); + const int64 input_spatial_dim = dnums.input_spatial_dimensions(ki); + const int64 output_spatial_dim = + dnums.output_spatial_dimensions(ki); // Calculate lhs (input) index without taking base dilation into // account. const auto& window_dim = window.dimensions(ki); const int64 undilated_index = - out_index[spatial_dim] * window_dim.stride() - + out_index[output_spatial_dim] * window_dim.stride() - window_dim.padding_low() + rhs_spatial_index[ki] * window_dim.window_dilation(); // Skip if the lhs (input) index is to be dilated. @@ -804,23 +895,26 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } // Calculate the actual lhs (input) index after dilation. - lhs_index[spatial_dim] = + lhs_index[input_spatial_dim] = undilated_index / window_dim.base_dilation(); // Skip if input index is not in bound. - if (!(lhs_index[spatial_dim] >= 0 && - lhs_index[spatial_dim] < lhs_shape.dimensions(spatial_dim))) { + if (!(lhs_index[input_spatial_dim] >= 0 && + lhs_index[input_spatial_dim] < + lhs_shape.dimensions(input_spatial_dim))) { goto cnt; } rhs_index[dnums.kernel_spatial_dimensions(ki)] = - rhs_spatial_index[ki]; + window_dim.window_reversal() + ? ((window_dim.size() - 1) - rhs_spatial_index[ki]) + : rhs_spatial_index[ki]; } result_val += lhs_literal.Get(lhs_index) * rhs_literal.Get(rhs_index); } - cnt:; + cnt : {} } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); return result_val; @@ -1287,6 +1381,50 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template ::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], + ElementWiseUnaryOp(sin, [](ReturnT elem_operand) { + return std::sin(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleSin(HloInstruction* sin) { + return InvalidArgument("Unsupported type for Sin"); + } + + Status HandleSin(HloInstruction* sin) override { + return HandleSin(sin); + } + + template ::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], + ElementWiseUnaryOp(cos, [](ReturnT elem_operand) { + return std::cos(elem_operand); + })); + return Status::OK(); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleCos(HloInstruction* cos) { + return InvalidArgument("Unsupported type for Cos"); + } + + Status HandleCos(HloInstruction* cos) override { + return HandleCos(cos); + } + private: template StatusOr> DynamicSlice( @@ -1397,8 +1535,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const auto* rhs = instruction->operand(1); const auto* ehs = instruction->operand(2); - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. + // TODO(b/35950897, b/27796129): add DCHECK back once implicit + // broadcast is removed. if (!(ShapeUtil::SameDimensions(shape, lhs->shape()) && ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()) && ShapeUtil::SameDimensions(rhs->shape(), ehs->shape()))) { @@ -1450,6 +1588,10 @@ HloEvaluator::HloEvaluator() { typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); + + typed_visitors_[BF16] = MakeUnique([](HloInstruction*) { + return Unimplemented("HloEvaluator: unhandled primitive type: BF16."); + }); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); @@ -1561,6 +1703,7 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( } std::vector operands; + operands.reserve(owned_operands.size()); for (auto& operand : owned_operands) { operands.push_back(operand.get()); } diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 67b6e215fcb23598f1a8ab6212d6e7e58a64e976..7557aaa2484d184555411a79d8dce2c9241427b0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -39,16 +39,18 @@ class HloEvaluator : public DfsHloVisitorWithDefault { HloEvaluator(); // Evaluates an HLO module and an array of pointers to literals. // Returns the evaluated result as a literal if successful. - // Precondition: argument literals correspond to each input computation's - // parameters in their post-ordering. See comment below for example. + // Precondition: The indices of arg_literals correspond to the parameter + // numbers of the HLO parameters in the computation. See comment below for an + // example. StatusOr> Evaluate( const HloModule& module, tensorflow::gtl::ArraySlice arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. - // Precondition: argument literals correspond to the input computation's - // parameters in their post-ordering. For e.g., consider the following graph: + // Precondition: The indices of arg_literals correspond to the parameter + // numbers of the HLO parameters in the computation. For e.g., consider the + // following graph: // // * // / \ @@ -57,8 +59,9 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // / \ // Parameter0 Constant // - // The input literals array will have its first literal map to Parameter0 and - // the second map to Parameter1. + // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number + // 1 in this computation. The input literals array will then have its first + // literal map to Parameter0 and the second map to Parameter1. StatusOr> Evaluate( const HloComputation& computation, tensorflow::gtl::ArraySlice arg_literals); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 85477af6fe26f53504c07204348566c16a24392c..a5d39fe08699f1ec17462f3ac5600fbe2191f307 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -46,20 +46,57 @@ class HloEvaluatorTest : public HloVerifiedTestBase { HloEvaluatorTest() { evaluator_ = MakeUnique(); } std::unique_ptr evaluator_; + + void TestUnaryOp(HloOpcode opcode, std::unique_ptr expected, + std::unique_ptr input, float aabs = 0) { + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); + auto instruction = b.AddInstruction( + HloInstruction::CreateUnary(expected->shape(), opcode, c1)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto element_type = expected->shape().element_type(); + if (element_type == F32 || element_type == F64) { + ErrorSpec error(aabs); + LiteralTestUtil::ExpectNear(*expected, *result, error); + } else { + LiteralTestUtil::ExpectEqual(*expected, *result); + } + } + + void TestBinaryOp(HloOpcode opcode, std::unique_ptr expected, + std::unique_ptr lhs, + std::unique_ptr rhs) { + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); + auto instruction = b.AddInstruction( + HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*expected, *result); + } }; // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. TEST_F(HloEvaluatorTest, DoesClamp) { auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); Shape shape = low->shape(); HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); - auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); - auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); auto instruction = b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); @@ -72,6 +109,28 @@ TEST_F(HloEvaluatorTest, DoesClamp) { LiteralTestUtil::ExpectEqual(*expected, *result); } +TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { + auto low = Literal::CreateR0(0.f); + auto value = Literal::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); + auto high = Literal::CreateR0(1.f); + + Shape shape = value->shape(); + HloComputation::Builder b(TestName()); + auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); + auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); + auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); + auto instruction = b.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. TEST_F(HloEvaluatorTest, DoesSelect) { @@ -103,120 +162,101 @@ TEST_F(HloEvaluatorTest, DoesSelect) { TEST_F(HloEvaluatorTest, DoesAdd) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); - - Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); - auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); - auto instruction = b.AddInstruction( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kAdd, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise and with 2 operands. +TEST_F(HloEvaluatorTest, DoesAnd) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto expected = Literal::CreateR2({{0, 0}, {4, 4}}); + TestBinaryOp(HloOpcode::kAnd, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise or with 2 operands. +TEST_F(HloEvaluatorTest, DoesOr) { + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto expected = Literal::CreateR2({{3, 4}, {-100, 4}}); + TestBinaryOp(HloOpcode::kOr, std::move(expected), std::move(lhs), + std::move(rhs)); +} +// Verifies that HloEvaluator evaluates a HLO instruction that performs +// element-wise multiply with 2 operands. +TEST_F(HloEvaluatorTest, DoesMultiply) { + auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2( + {{std::numeric_limits::min(), 4}, {4, 4}}); + auto expected = Literal::CreateR2( + {{std::numeric_limits::min(), 0}, {-400, 16}}); + TestBinaryOp(HloOpcode::kMultiply, std::move(expected), std::move(lhs), + std::move(rhs)); } - // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. TEST_F(HloEvaluatorTest, DoesDivideInt64) { - auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); - - Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64))); - auto c2_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_s64, HloOpcode::kDivide, c1_s64, c2_s64)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), + std::move(rhs)); } TEST_F(HloEvaluatorTest, DoesDivideDouble) { - auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); - - Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64))); - auto c2_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_f64, HloOpcode::kDivide, c1_f64, c2_f64)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - + auto lhs = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), + std::move(rhs)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. TEST_F(HloEvaluatorTest, DoesAbsR2) { auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); - const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = - b.AddInstruction(HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_F(HloEvaluatorTest, DoesAbsR0) { - // For R0 literal. - const Shape& r0 = ShapeUtil::MakeShape(F32, {}); auto operand = Literal::CreateR0(-1.0f); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = - b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); auto expected = Literal::CreateR0(1.0f); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) { - // For R1 literal with dimension of size 0. - Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); auto operand = Literal::CreateR1({}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = b.AddInstruction( - HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1)); - module().AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); auto expected = Literal::CreateR1({}); - - LiteralTestUtil::ExpectEqual(*expected, *result); + TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); +} +TEST_F(HloEvaluatorTest, DoesNegateR2) { + auto operand = Literal::CreateR2( + {{0, std::numeric_limits::min()}, {-1, 4}}); + auto expected = + Literal::CreateR2({{0, std::numeric_limits::min()}, {1, -4}}); + TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); +} +TEST_F(HloEvaluatorTest, DoesCosR2) { + auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); + TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand)); +} +TEST_F(HloEvaluatorTest, DoesSinR2) { + auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); + auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); + TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), + 0x1.0P-20); +} +TEST_F(HloEvaluatorTest, DoesNotR2) { + auto operand = + Literal::CreateR2({{0, std::numeric_limits::min()}, + {-1, std::numeric_limits::max()}}); + auto expected = + Literal::CreateR2({{-1, std::numeric_limits::max()}, + {0, std::numeric_limits::min()}}); + TestUnaryOp(HloOpcode::kNot, std::move(expected), std::move(operand)); } - // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { @@ -581,8 +621,11 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = @@ -624,8 +667,11 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = @@ -665,8 +711,11 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = @@ -711,7 +760,8 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_kernel_output_feature_dimension(0); dnums.set_kernel_input_feature_dimension(1); @@ -794,6 +844,85 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { LiteralTestUtil::ExpectEqual(*expected, *result); } +TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { + HloComputation::Builder b(TestName()); + + // clang-format off + // Input dimensions: [feature=2, height=3, batch=1, width=4] + Array4D input({ + {{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}}, + {{{13, 14, 15, 16}}, + {{17, 18, 19, 20}}, + {{21, 22, 23, 24}}} + }); + // Weight dimensions: + // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3] + Array4D weight({{ + {{1, 7, 13}, + {4, 10, 16}}, + {{2, 8, 14}, + {5, 11, 17}}, + {{3, 9, 15}, + {6, 12, 18}} + }}); + // clang-format on + + auto lhs_literal = Literal::CreateR4FromArray4D(input); + HloInstruction* lhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); + + auto rhs_literal = Literal::CreateR4FromArray4D(weight); + HloInstruction* rhs_instruction = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); + rhs_instruction = b.AddInstruction(HloInstruction::CreateReverse( + rhs_instruction->shape(), rhs_instruction, {3, 1})); + + Window window; + WindowDimension dim; + dim.set_size(3); + dim.set_stride(1); + dim.set_padding_low(0); + dim.set_padding_high(0); + dim.set_window_dilation(1); + dim.set_base_dilation(1); + dim.set_window_reversal(true); + *window.add_dimensions() = dim; + *window.add_dimensions() = dim; + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(2); + dnums.set_output_batch_dimension(2); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); + + dnums.set_kernel_output_feature_dimension(0); + dnums.set_kernel_input_feature_dimension(2); + dnums.add_kernel_spatial_dimensions(3); + dnums.add_kernel_spatial_dimensions(1); + + const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + b.AddInstruction(HloInstruction::CreateConvolve( + shape, lhs_instruction, rhs_instruction, window, dnums)); + auto computation = module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + + // clang-format off + // Result dimensions: [feature=1, height=1, batch=1, width=2] + Array4D expected_array({{{{2514, 2685}}}}); + // clang-format on + auto expected = Literal::CreateR4FromArray4D(expected_array); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} + TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { HloComputation::Builder b(TestName()); @@ -843,8 +972,10 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { dnums.set_output_batch_dimension(2); dnums.set_input_feature_dimension(0); dnums.set_output_feature_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_kernel_output_feature_dimension(0); dnums.set_kernel_input_feature_dimension(2); diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index eaeb352183bdf6cc7f4a164c31af4f641e37440e..0809fe780d21baf366b63bdab118653630c33872 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -26,45 +26,110 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" namespace xla { - -void HloExecutionProfile::AddProfileResult(const HloInstruction* hlo, - uint64 cycles_taken) { - hlo_to_cycles_taken_[hlo] = cycles_taken; - profiled_computations_.insert(hlo->parent()); +HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { + size_t current_profile_index = 0; + for (xla::HloComputation* computation : module.MakeComputationPostOrder()) { + InsertOrDie(&computation_to_profile_idx_, computation, + current_profile_index++); + for (const HloInstruction* instruction : computation->instructions()) { + // For simplicity we track all instrutions here, but we could skip + // non-executing instructions like constants and parameters. + InsertOrDie(&instruction_to_profile_idx_, instruction, + current_profile_index++); + } + } } -uint64 HloExecutionProfile::GetProfileResult(const HloInstruction& hlo) const { - auto iter = hlo_to_cycles_taken_.find(&hlo); - if (iter == hlo_to_cycles_taken_.end()) { - return 0; +std::unique_ptr CreateHloProfilePrinter( + const HloProfileIndexMap& hlo_profile_index_map, + const HloCostAnalysis& cost_analysis) { + using HloComputationInfo = HloProfilePrinter::HloComputationInfo; + using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo; + + HloComputationInfo* computation_infos = + new HloComputationInfo[hlo_profile_index_map.computation_count()]; + + // There are two "indices" in play here. The first one is the index of the + // HloComputationInfo or HloInstructionInfo in the array that contains said + // HloComputationInfo or HloInstructionInfo. The second index is the index of + // the HloComputationInfo or HloInstructionInfo in the profile counters array, + // as decided by hlo_profile_index_map. The latter index is always referred + // to as "profile_index". + + size_t computation_index_in_static_data = 0; + size_t max_profile_index = hlo_profile_index_map.total_count(); + for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) { + CHECK_LT(pair.second, max_profile_index); + const HloComputation* computation = pair.first; + size_t current_computation_index = computation_index_in_static_data++; + HloComputationInfo* computation_info = + &computation_infos[current_computation_index]; + + computation_info->name = strdup(computation->name().c_str()); + computation_info->profile_index = pair.second; + computation_info->instructions = + new HloInstructionInfo[computation->instruction_count()]; + computation_info->instructions_size = computation->instruction_count(); + + size_t instruction_index_in_static_data = 0; + for (const HloInstruction* hlo : computation->instructions()) { + HloProfilePrinter::HloInstructionInfo* instruction_info = + &computation_info->instructions[instruction_index_in_static_data++]; + instruction_info->long_name = strdup(hlo->ToString().c_str()); + instruction_info->short_name = + strdup(hlo->ToString(/*compact_operands=*/true).c_str()); + instruction_info->category = strdup(hlo->ToCategory().c_str()); + instruction_info->flop_count = cost_analysis.flop_count(*hlo); + instruction_info->transcendental_count = + cost_analysis.transcendental_count(*hlo); + instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo); + instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo); + instruction_info->profile_index = + hlo_profile_index_map.GetProfileIndexFor(*hlo); + CHECK_LT(instruction_info->profile_index, max_profile_index); + } } - return iter->second; + + auto deleter = [](HloProfilePrinter::HloComputationInfo* computation_infos, + int64 computation_infos_size) { + for (int64 i = 0; i < computation_infos_size; i++) { + HloInstructionInfo* instruction_infos = computation_infos[i].instructions; + for (int64 j = 0; j < computation_infos[i].instructions_size; j++) { + // We can't make instruction_infos[j].long_name etc. non-const pointers + // since they may point into static storage, so we have a const_cast + // here. + free(const_cast(instruction_infos[j].long_name)); + free(const_cast(instruction_infos[j].short_name)); + free(const_cast(instruction_infos[j].category)); + } + delete[] instruction_infos; + free(const_cast(computation_infos[i].name)); + } + delete[] computation_infos; + }; + + return MakeUnique( + computation_infos, hlo_profile_index_map.computation_count(), + /*profile_counters_size=*/max_profile_index, deleter); } -string HloExecutionProfile::ToString( - const HloComputation& computation, - const DeviceDescription& device_description, - HloCostAnalysis* cost_analysis) const { - tensorflow::Status analysis_status = computation.Accept(cost_analysis); - if (!analysis_status.ok()) { - return ""; - } +HloExecutionProfile::HloExecutionProfile( + const HloProfilePrinter* hlo_profile_printer, + const HloProfileIndexMap* hlo_profile_index_map) + : hlo_profile_printer_(*hlo_profile_printer), + hlo_profile_index_map_(*hlo_profile_index_map), + profile_counters_( + /*count*/ hlo_profile_index_map_.total_count(), + /*value*/ 0) {} - HumanReadableProfileBuilder builder(computation.name(), - total_cycles_executed(computation), - device_description.clock_rate_ghz()); - for (const auto& item : hlo_to_cycles_taken_) { - const HloInstruction* hlo = item.first; - int64 cycles = item.second; - - builder.AddOp(/*op_name=*/hlo->ToString(), - /*short_name=*/hlo->ToString(/*compact_operands=*/true), - hlo->ToCategory(), cycles, cost_analysis->flop_count(*hlo), - cost_analysis->transcendental_count(*hlo), - cost_analysis->bytes_accessed(*hlo), - cost_analysis->seconds(*hlo)); - } - return builder.ToString(); +void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, + uint64 cycles_taken) { + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(*hlo)] = + cycles_taken; +} + +uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { + return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(hlo)]; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index a980c1617f395fc6668b8f8739e04d18fd1b689e..470fd4ce3c205d84152238f4b18daad77e403f68 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -18,7 +18,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/types.h" @@ -27,6 +29,59 @@ namespace xla { class HloInstruction; +// Maps all HloInstructions and HloComputations in an HloModule to integers. +// These integers form the contiguous range [0, total_count()). +class HloProfileIndexMap { + public: + // Scans `module` to populate this instance of HloProfileIndexMap. + explicit HloProfileIndexMap(const HloModule& module); + + HloProfileIndexMap(const HloProfileIndexMap&) = default; + HloProfileIndexMap(HloProfileIndexMap&&) = default; + + HloProfileIndexMap& operator=(const HloProfileIndexMap&) = default; + HloProfileIndexMap& operator=(HloProfileIndexMap&&) = default; + + size_t GetProfileIndexFor(const HloInstruction& instruction) const { + return FindOrDie(instruction_to_profile_idx(), &instruction); + } + + size_t GetProfileIndexFor(const HloComputation& computation) const { + return FindOrDie(computation_to_profile_idx(), &computation); + } + + size_t instruction_count() const { + return instruction_to_profile_idx().size(); + } + + size_t computation_count() const { + return computation_to_profile_idx().size(); + } + + size_t total_count() const { + return instruction_count() + computation_count(); + } + + const std::unordered_map& + instruction_to_profile_idx() const { + return instruction_to_profile_idx_; + } + + const std::unordered_map& + computation_to_profile_idx() const { + return computation_to_profile_idx_; + } + + private: + std::unordered_map instruction_to_profile_idx_; + std::unordered_map computation_to_profile_idx_; +}; + +// Create an instance of `HloProfilePrinter` that owns its memory. +std::unique_ptr CreateHloProfilePrinter( + const HloProfileIndexMap& hlo_profile_index_map, + const HloCostAnalysis& cost_analysis); + // Describes how much time each HLO operation took. // // Each HloComputation takes a certain number of cycles. This class helps break @@ -35,26 +90,27 @@ class HloExecutionProfile { public: using DeviceDescription = perftools::gputools::DeviceDescription; + HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer, + const HloProfileIndexMap* hlo_profile_index_map); + // Record how many cycles this HLO took to execute. - void AddProfileResult(const HloInstruction* hlo, uint64 cycles_taken); + void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken); // Returns how many cycles this HLO took to execute. Profiling information // may not be available for some instructions in which case zero is returned. - uint64 GetProfileResult(const HloInstruction& hlo) const; + uint64 GetCyclesTakenBy(const HloInstruction& hlo) const; // Return the number of cycles this computation took to execute. uint64 total_cycles_executed(const HloComputation& computation) const { - auto it = total_cycles_executed_.find(&computation); - if (it != total_cycles_executed_.end()) { - return it->second; - } - return 0; + return profile_counters_[hlo_profile_index_map_.GetProfileIndexFor( + computation)]; } // Record how many cycles a computation took to execute. void set_total_cycles_executed(const HloComputation& computation, uint64 total_cycles_executed) { - total_cycles_executed_[&computation] = total_cycles_executed; + profile_counters_[hlo_profile_index_map_.GetProfileIndexFor(computation)] = + total_cycles_executed; } // Returns a version of the execution profile suitable for performance @@ -63,25 +119,20 @@ class HloExecutionProfile { // for the operations in a given computation. Returns an empty string if it // wasn't possible to generate a printable version. cost_analysis should be a // clean analysis that can be used to visit the computation. - string ToString(const HloComputation& computation, - const DeviceDescription& device_description, - HloCostAnalysis* cost_analysis) const; - - // Returns the computations we have profiled. - std::unordered_set profiled_computations() const { - return profiled_computations_; + string ToString(const DeviceDescription& device_description) const { + return hlo_profile_printer_.ToString(profile_counters_.data(), + device_description.clock_rate_ghz()); } - private: - // Contains a mapping from HLO to the number of cycles it took to execute it. - std::unordered_map hlo_to_cycles_taken_; + std::vector* mutable_profile_counters() { return &profile_counters_; } - // If non-empty, contains the total number of cycles a computation took to - // execute. - std::unordered_map total_cycles_executed_; + private: + const HloProfilePrinter& hlo_profile_printer_; + const HloProfileIndexMap& hlo_profile_index_map_; - // The computations we have profiled. - std::unordered_set profiled_computations_; + // Stores per-Hlo profile counters. This is the only thing that changes when + // we execute an XLA computation. + std::vector profile_counters_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1e6729e2bccad4bdbe075a635d8a9b1ede6fecb --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +class HloExecutionProfileTest : public HloTestBase { + protected: + static constexpr int64 kInstructionCyclesIndex = 0; + static constexpr int64 kInstructionNameIndex = 19; +}; + +// Splits `lines` into a sequence of lines delimited by newlines and then split +// each of those lines into a sequence of words delimited by spaces. Filter out +// empty words. +std::vector> SplitIntoLinesAndWords( + tensorflow::StringPiece lines) { + std::vector> result; + for (const string& line : tensorflow::str_util::Split(lines, '\n')) { + std::vector words; + for (const string& word : tensorflow::str_util::Split(line, ' ')) { + if (!word.empty()) { + words.push_back(word); + } + } + result.push_back(std::move(words)); + } + + return result; +} + +TEST_F(HloExecutionProfileTest, Basic) { + std::unique_ptr hlo_module = CreateNewModule(); + + HloComputation::Builder builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {30, 30}); + HloInstruction* param_lhs = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs")); + HloInstruction* param_rhs = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs")); + HloInstruction* add_instruction = + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, param_lhs, param_rhs)); + HloInstruction* dot_instruction = + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kDot, param_lhs, add_instruction)); + + hlo_module->AddEntryComputation(builder.Build()); + + auto shape_size_function = [&](const Shape& shape) { + const int64 pointer_size = 8; + if (ShapeUtil::IsOpaque(shape)) { + return pointer_size; + } + return ShapeUtil::ByteSizeOf(shape, pointer_size); + }; + + HloCostAnalysis cost_analysis(shape_size_function); + HloProfileIndexMap profile_index_map(*hlo_module); + std::unique_ptr profile_printer = + CreateHloProfilePrinter(profile_index_map, cost_analysis); + HloExecutionProfile execution_profile(profile_printer.get(), + &profile_index_map); + + const int64 add_cycles = 1000; + const int64 dot_cycles = 4000; + + execution_profile.SetCyclesTakenBy(add_instruction, add_cycles); + execution_profile.SetCyclesTakenBy(dot_instruction, dot_cycles); + + string rendered_profile = execution_profile.ToString( + backend().default_stream_executor()->GetDeviceDescription()); + std::vector> lines_and_words = + SplitIntoLinesAndWords(rendered_profile); + ASSERT_EQ(lines_and_words.size(), 8); + + const std::vector& line_2 = lines_and_words[2]; + const std::vector& line_3 = lines_and_words[3]; + + EXPECT_EQ(line_2[kInstructionCyclesIndex], std::to_string(dot_cycles)); + EXPECT_EQ(line_2[kInstructionNameIndex], '%' + dot_instruction->name()); + + EXPECT_EQ(line_3[kInstructionCyclesIndex], std::to_string(add_cycles)); + EXPECT_EQ(line_3[kInstructionNameIndex], '%' + add_instruction->name()); +} +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index d7bdd4117d947add448ff660abc621d9ae3118b6..84187d578346eafd5e32727a15f5eab9cc79feef 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -312,11 +312,11 @@ optional MatchTrivialComputation(const HloComputation* computation) { class HloDotDumper { public: HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label, - bool show_addresses, bool show_metadata, + const DebugOptions& debug_options, bool show_metadata, const HloExecutionProfile* profile, NodeFilter filter) : computation_(computation), label_(label.ToString()), - show_addresses_(show_addresses), + debug_options_(debug_options), show_metadata_(show_metadata), profile_(profile), filter_(std::move(filter)) {} @@ -382,7 +382,7 @@ class HloDotDumper { const HloComputation* computation_; // never null const string label_; // overall name for the graph - const bool show_addresses_; + const DebugOptions& debug_options_; const bool show_metadata_; const HloExecutionProfile* profile_; // may be null const NodeFilter filter_; @@ -414,6 +414,11 @@ class HloDotDumper { // appears before both the inner computation and the destination node are // defined. std::vector edges_; + + // When coloring by sharding information, we track the sharding string + // representation to color association, by round-robin the color schemes. + std::unordered_map sharding_colors_; + int64 next_shard_color_ = 0; }; string HloDotDumper::Dump() { @@ -734,15 +739,16 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); AddInstructionIncomingEdges(instr); - // Override the node's styling if it should be (de-)emphasized. - if (filter_.Deemphasized(instr)) { - color = kDashedBorder; - } - if (filter_.Highlight(instr)) { - node_shape = "diamond"; - color = kDarkRed; + if (!debug_options_.xla_hlo_graph_sharding_color()) { + // Override the node's styling if it should be (de-)emphasized. + if (filter_.Deemphasized(instr)) { + color = kDashedBorder; + } + if (filter_.Highlight(instr)) { + node_shape = "diamond"; + color = kDarkRed; + } } - // Build the text that will be displayed inside the node. string node_body = node_label; for (const string& s : @@ -761,12 +767,22 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { auto stringify_constant = [](const HloInstruction* constant) { - if (ShapeUtil::IsEffectiveScalar(constant->shape())) { - auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( - constant->shape(), /*linear_index=*/0); - return Printf("%s (%s)", constant->literal().GetAsString(elem_idx), + const auto& shape = constant->shape(); + + // Print the literal value of constants with <= K elements. + optional elem_count; + if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) { + elem_count = 1; + for (int64 dim : shape.dimensions()) { + *elem_count *= dim; + } + } + if (elem_count.has_value() && *elem_count <= 8) { + return Printf("%s (%s)", constant->literal().ToString(), ShapeUtil::HumanString(constant->shape())); } + + // Otherwise, print e.g. "%constant.42 (s32[100])". string constant_name; if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { constant_name = constant->name(); @@ -817,6 +833,20 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { + if (debug_options_.xla_hlo_graph_sharding_color()) { + if (!instr->has_sharding()) { + return kDashedBorder; + } + string shard_str = instr->sharding().ToString(); + auto it = sharding_colors_.find(shard_str); + if (it != sharding_colors_.end()) { + return it->second; + } + ColorScheme color = static_cast( + kBlue + (next_shard_color_++ % (kDashedBorder - kBlue))); + sharding_colors_.emplace(shard_str, color); + return color; + } const auto kParameterColor = kOrange; // Special case: If this instruction has a parameter merged into it, paint it @@ -834,9 +864,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { // (eg, parameter). switch (instr->opcode()) { case HloOpcode::kAbs: - case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kAtan2: + case HloOpcode::kBitcastConvert: case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kComplex: @@ -852,18 +883,19 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kAnd: - case HloOpcode::kNot: - case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kPower: case HloOpcode::kReal: case HloOpcode::kRemainder: + case HloOpcode::kRng: + case HloOpcode::kRoundNearestAfz: case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: @@ -873,7 +905,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSort: case HloOpcode::kSubtract: case HloOpcode::kTanh: - case HloOpcode::kRng: // De-emphasize scalar-shaped elementwise ops -- they're generally // uninteresting. if (ShapeUtil::IsEffectiveScalar(instr->shape())) { @@ -881,9 +912,9 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { } return kYellow; case HloOpcode::kBitcast: - case HloOpcode::kTuple: - case HloOpcode::kTrace: case HloOpcode::kGetTupleElement: + case HloOpcode::kTrace: + case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: // De-emphasize nodes which broadcast a scalar within a fusion node -- @@ -922,25 +953,28 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kRed; case HloOpcode::kParameter: return kParameterColor; - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: case HloOpcode::kReduce: - case HloOpcode::kSelectAndScatter: case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: return kPurple; - case HloOpcode::kMap: case HloOpcode::kFusion: + case HloOpcode::kMap: return kGray; - case HloOpcode::kSend: - case HloOpcode::kRecv: + case HloOpcode::kCrossReplicaSum: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: return kBrown; + case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kCustomCall: case HloOpcode::kWhile: - case HloOpcode::kCall: return kDarkGreen; case HloOpcode::kConstant: LOG(FATAL) << "Constants don't get their own nodes in the graph."; @@ -969,10 +1003,13 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) { return Printf("%s", HtmlLikeStringSanitize(instr->name())); } - + string extended_opcode = + StrCat(HloOpcodeString(instr->opcode()), + instr->opcode() != HloOpcode::kFusion + ? "" + : StrCat(":", xla::ToString(instr->fusion_kind()))); // If the name does not contain the opcode, render both. - return Printf("%s
%s", - HtmlLikeStringSanitize(instr->ExtendedOpcodeStr()), + return Printf("%s
%s", HtmlLikeStringSanitize(extended_opcode), HtmlLikeStringSanitize(instr->name())); } @@ -1027,7 +1064,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { ? "" : StrCat("stride=", VectorString(instr->slice_strides())); case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: return StrCat("channel_id=", instr->channel_id()); default: return ""; @@ -1065,12 +1104,11 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { } lines.push_back(instr_shape); } - - if (show_addresses_) { + if (debug_options_.xla_hlo_graph_addresses()) { lines.push_back(Printf("[%p]", instr)); } if (profile_ != nullptr) { - double hlo_cycles_executed = profile_->GetProfileResult(*instr); + double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); double total_cycles_executed = profile_->total_cycles_executed(*instr->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { @@ -1163,70 +1201,36 @@ const HloInstruction* HloDotDumper::GetNodeForEdge( return instr; } -tensorflow::mutex& RendererMutex() { - static tensorflow::mutex* mu = new tensorflow::mutex; - return *mu; -} +class GraphRendererRegistry { + public: + void AddRenderer(GraphRendererInterface* graph_renderer) { + tensorflow::mutex_lock lock(mu_); + graph_renderer_ = graph_renderer; + } -std::map* GraphRenderers() { - static auto* graph_renderers = new std::map(); - return graph_renderers; -} + GraphRendererInterface* GetDefaultRenderer() { + tensorflow::mutex_lock lock(mu_); + return graph_renderer_; + } -GraphRendererInterface* GetGraphRenderer() { - tensorflow::mutex_lock lock(RendererMutex()); - auto* graph_renderers = GraphRenderers(); - auto it = graph_renderers->rbegin(); - CHECK(it != graph_renderers->rend()) << "No registered graph dumpers"; - return it->second; -} + static GraphRendererRegistry* Default() { + static GraphRendererRegistry* registry = new GraphRendererRegistry(); + return registry; + } + + private: + tensorflow::mutex mu_; + GraphRendererInterface* graph_renderer_ = nullptr; +}; } // namespace -Registrar::Registrar(GraphRendererInterface* dumper, int priority) { - tensorflow::mutex_lock lock(RendererMutex()); - auto* graph_renderers = GraphRenderers(); - graph_renderers->emplace(priority, dumper); +Registrar::Registrar(GraphRendererInterface* dumper) { + GraphRendererRegistry::Default()->AddRenderer(dumper); } namespace { -class FileGraphRenderer : public GraphRendererInterface { - public: - string RenderGraph(const string& graph, GraphKind graph_kind, - const DebugOptions& debug_options) override { - static std::atomic output_num(0); - string file_extension; - switch (graph_kind) { - case DOT_GRAPH: - file_extension = ".dot"; - break; - case TF_GRAPHDEF: - file_extension = ".pbtxt"; - break; - } - string path = - JoinPath(debug_options.xla_hlo_graph_path(), - StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); - auto status = Status::OK(); - int fd = mkstemps(&path[0], file_extension.length()); - if (fd < 0) { - status = - Status(tensorflow::error::Code::UNKNOWN, - StrCat("Failed to create temporary file to dump HLO graph: ", - strerror(errno))); - } else { - status = tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, - graph); - close(fd); - } - if (!status.ok()) { - LOG(WARNING) << "Saving HLO graph failed: " << status; - } - return path; - } -}; - // Gets a NodeFilter that includes roughly all instructions whose distance from // root is <= radius. NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { @@ -1289,7 +1293,9 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { auto is_displayed = [&](const HloInstruction* instr) { // Constants are displayed inline with their users; they're never omitted. - return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant; + // Nodes in subcomputations are always shown. + return nodes.count(instr) > 0 || instr->opcode() == HloOpcode::kConstant || + instr->parent() != root->parent(); }; // Make a second pass over 'nodes' to fix up the NodeFilterResults now that we @@ -1334,7 +1340,54 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { }); } -XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0); +string SaveGraph(const string& graph, + GraphRendererInterface::GraphKind graph_kind, + const string& dest_path) { + static std::atomic output_num(0); + string file_extension; + switch (graph_kind) { + case GraphRendererInterface::DOT_GRAPH: + file_extension = ".dot"; + break; + case GraphRendererInterface::TF_GRAPHDEF: + file_extension = ".pbtxt"; + break; + } + string path = JoinPath( + dest_path, StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); + auto status = Status::OK(); + int fd = mkstemps(&path[0], file_extension.length()); + if (fd < 0) { + status = + Status(tensorflow::error::Code::UNKNOWN, + StrCat("Failed to create temporary file to dump HLO graph: ", + strerror(errno))); + } else { + status = + tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph); + close(fd); + } + if (!status.ok()) { + LOG(WARNING) << "Saving HLO graph failed: " << status; + } + return path; +} + +string ExportGraph(const string& graph, + GraphRendererInterface::GraphKind graph_kind, + const DebugOptions& debug_options) { + string path = debug_options.xla_hlo_graph_path(); + if (!path.empty()) { + return SaveGraph(graph, graph_kind, path); + } else { + auto graph_renderer = + GraphRendererRegistry::Default()->GetDefaultRenderer(); + CHECK(graph_renderer != nullptr) + << "No registered renderer for the HLO graph. " + "Use --xla_hlo_graph_path=PATH to export to local file system"; + return graph_renderer->RenderGraph(graph, graph_kind, debug_options); + } +} } // namespace @@ -1342,27 +1395,22 @@ string DumpGraph(const HloComputation& computation, const string& label, const DebugOptions& debug_options, const HloExecutionProfile* hlo_execution_profile, bool show_metadata) { + GraphRendererInterface::GraphKind graph_kind; string graph; - string graph_url; if (debug_options.xla_hlo_dump_as_graphdef()) { - HloTfGraphBuilder builder; + HloTfGraphBuilder builder(debug_options); TF_CHECK_OK(builder.AddComputation(computation)); CHECK(tensorflow::protobuf::TextFormat::PrintToString(builder.GetGraphDef(), &graph)); - // TODO(b/37198616): Use the default registered renderers when all - // renderers support rendering GraphDefs. Always dump GraphDefs to files - // for now. - graph_url = FileGraphRenderer().RenderGraph( - graph, GraphRendererInterface::TF_GRAPHDEF, debug_options); + graph_kind = GraphRendererInterface::TF_GRAPHDEF; } else { - graph = - HloDotDumper(&computation, label, - /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - show_metadata, hlo_execution_profile, NodeFilter()) - .Dump(); - graph_url = GetGraphRenderer()->RenderGraph( - graph, GraphRendererInterface::DOT_GRAPH, debug_options); + graph = HloDotDumper(&computation, label, debug_options, show_metadata, + hlo_execution_profile, NodeFilter()) + .Dump(); + graph_kind = GraphRendererInterface::DOT_GRAPH; } + + string graph_url = ExportGraph(graph, graph_kind, debug_options); LOG(INFO) << "computation " << computation.name() << " [" << label << "]: " << graph_url; return graph_url; @@ -1375,12 +1423,10 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius, StrCat("Neighborhood of ", radius, " nodes around ", node.name()); NodeFilter filter = MakeNodeFilter(&node, radius); string graph = - HloDotDumper(node.parent(), label, - /*show_addresses=*/debug_options.xla_hlo_graph_addresses(), - show_metadata, /*profile=*/nullptr, filter) + HloDotDumper(node.parent(), label, debug_options, show_metadata, + /*profile=*/nullptr, filter) .Dump(); - return GetGraphRenderer()->RenderGraph( - graph, GraphRendererInterface::DOT_GRAPH, debug_options); + return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options); } void DumpText(const HloModule& module, const string& label, @@ -1391,7 +1437,8 @@ void DumpText(const HloModule& module, const string& label, string filename = do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); string path = JoinPath(directory_path, filename); - TF_CHECK_OK(WriteStringToFile(env, path, module.ToString())); + TF_CHECK_OK(WriteStringToFile( + env, path, module.ToString(/*include_large_constants=*/true))); LOG(INFO) << "dumping module '" << module.name() << "' to " << path; } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h index dd304ec76cd903a6175337551fc50808b1797104..2704aae1e3ba7fb131bfcb1287d807d785fd9774 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.h +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h @@ -84,11 +84,10 @@ void DumpText(const HloModule& module, const string& label, // Internal implementation details below this point. -// Class that registers a graph renderer. Higher-priority renders are chosen -// first. +// Class that registers a graph renderer. class Registrar { public: - Registrar(GraphRendererInterface* dumper, int priority); + Registrar(GraphRendererInterface* dumper); }; #define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \ diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 7b0f937f383a416f805a799bd6787afe15b324b0..8e1531c87f9c6e133e2d6763b046b1d5dcbcd09f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -45,7 +45,7 @@ class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { string last_graph_; }; -XLA_REGISTER_GRAPH_RENDERER(DotRenderer, std::numeric_limits::max()); +XLA_REGISTER_GRAPH_RENDERER(DotRenderer); TEST(HloGraphDumperTest, NestedFusion) { HloComputation::Builder b("b"); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index e6a4f68fb38001a65ea4d9d0b2b1ddaca4d85106..784930195796220646e80cc1cd7a1b342083acfc 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -43,6 +43,7 @@ limitations under the License. namespace xla { +using tensorflow::str_util::CEscape; using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -51,7 +52,9 @@ using ::tensorflow::strings::StrCat; StatusOr> HloInstruction::CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, - tensorflow::gtl::FlatMap* computation_map) { + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation) { TF_RET_CHECK(!proto.opcode().empty()); TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); TF_RET_CHECK(proto.has_shape()); @@ -77,19 +80,19 @@ StatusOr> HloInstruction::CreateFromProto( TF_RET_CHECK(!proto.fusion_kind().empty()); TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, StringToFusionKind(proto.fusion_kind())); - TF_ASSIGN_OR_RETURN( - std::unique_ptr fused_computation, - HloComputation::CreateFromProto( - module, proto.fused_instructions_computation(), computation_map, - /*fusion_instruction=*/instruction.get())); - instruction->called_computations_.push_back( - module->AddEmbeddedComputation(std::move(fused_computation))); + TF_ASSIGN_OR_RETURN(std::unique_ptr fused_computation, + HloComputation::CreateFromProto( + module, proto.fused_instructions_computation(), + computation_map, add_fused_computation, + /*fusion_instruction=*/instruction.get())); + instruction->called_computations_.push_back(fused_computation.get()); + add_fused_computation(std::move(fused_computation)); } else { for (const string& computation_name : proto.called_computation_names()) { - TF_RET_CHECK(ContainsKey(*computation_map, computation_name)) + TF_RET_CHECK(ContainsKey(computation_map, computation_name)) << "No computation named " << computation_name; instruction->called_computations_.push_back( - computation_map->at(computation_name)); + computation_map.at(computation_name)); } } @@ -115,6 +118,10 @@ StatusOr> HloInstruction::CreateFromProto( MakeUnique( proto.convolution_dimension_numbers()); } + if (proto.has_dot_dimension_numbers()) { + instruction->dot_dimension_numbers_ = + MakeUnique(proto.dot_dimension_numbers()); + } for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { instruction->slice_starts_.push_back(slice_dimensions.start()); @@ -148,7 +155,7 @@ StatusOr> HloInstruction::CreateFromProto( WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); instruction->parameter_number_ = parameter_number; instruction->parameter_name_ = name; - instruction->name_ = "%" + name; + instruction->name_ = name; return instruction; } @@ -329,6 +336,17 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = + MakeUnique(dimension_numbers); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, @@ -343,12 +361,9 @@ HloInstruction::CreateReducePrecision(const Shape& shape, } /* static */ std::unique_ptr -HloInstruction::CreateCrossReplicaSum(const Shape& shape, - HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); - instruction->AppendOperand(operand); - return instruction; +HloInstruction::CreateCrossReplicaSum( + const Shape& shape, tensorflow::gtl::ArraySlice operands) { + return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -371,20 +386,50 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, /* static */ std::unique_ptr HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { + // Send instruction produces a tuple of {aliased operand, U32 context}. + Shape output_shape = ShapeUtil::MakeTupleShape( + {operand->shape(), ShapeUtil::MakeShape(U32, {})}); auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil())); + WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); instruction->AppendOperand(operand); instruction->channel_id_ = channel_id; return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateSendDone( + HloInstruction* operand) { + CHECK(operand->opcode() == HloOpcode::kSend) + << "SendDone must take the context operand from Send"; + auto instruction = WrapUnique( + new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); + instruction->AppendOperand(operand); + instruction->channel_id_ = operand->channel_id(); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape)); + // Recv instruction produces a tuple of {receive buffer, U32 context}. + Shape output_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); instruction->channel_id_ = channel_id; return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateRecvDone( + HloInstruction* operand) { + CHECK(operand->opcode() == HloOpcode::kRecv) + << "RecvDone must take the context operand from Recv"; + Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); + instruction->AppendOperand(operand); + instruction->channel_id_ = operand->channel_id(); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions) { @@ -405,6 +450,23 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateConditional( + const Shape& shape, HloInstruction* pred, + HloInstruction* true_computation_arg, HloComputation* true_computation, + HloInstruction* false_computation_arg, HloComputation* false_computation) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kConditional, shape)); + instruction->AppendOperand(pred); + instruction->AppendOperand(true_computation_arg); + instruction->AppendOperand(false_computation_arg); + // In called_computations_, the index of true_computation must be 0 and that + // of false computation must be 1, as defined by kTrueComputationIndex and + // kFalseComputationIndex. + instruction->called_computations_.push_back(true_computation); + instruction->called_computations_.push_back(false_computation); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateSlice( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice start_indices, @@ -468,6 +530,15 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBitcastConvert(const Shape& shape, + HloInstruction* operand) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape)); + instruction->AppendOperand(operand); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReduce( const Shape& shape, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions_to_reduce, @@ -600,7 +671,10 @@ HloInstruction::CreateSelectAndScatter( CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size()); CHECK(std::equal(operand->shape().dimensions().begin(), operand->shape().dimensions().end(), - Permute(dimensions, shape.dimensions()).begin())); + Permute(dimensions, shape.dimensions()).begin())) + << "shape: " << ShapeUtil::HumanString(shape) + << ", operand->shape(): " << ShapeUtil::HumanString(shape) + << ", dimensions: {" << Join(dimensions, ", ") << "}"; auto instruction = WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape)); instruction->AppendOperand(operand); @@ -618,6 +692,20 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateFusion( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->fusion_kind_ = fusion_kind; + instruction->called_computations_.push_back(fusion_computation); + fusion_computation->SetFusionInstruction(instruction.get()); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateFusionForBackwardConvolution( const Shape& shape, FusionKind fusion_kind, const Window& window, @@ -746,7 +834,7 @@ HloInstruction* HloInstruction::FuseInstructionInternal( HloInstruction* HloInstruction::CloneAndFuseInternal( HloInstruction* instruction_to_fuse, bool add_output) { CHECK_EQ(opcode_, HloOpcode::kFusion); - CHECK(instruction_to_fuse->IsFusable()); + CHECK(instruction_to_fuse->IsFusable()) << instruction_to_fuse->ToString(); VLOG(3) << "CloneAndFuseInternal:\n" << instruction_to_fuse->ToString(); HloInstruction* clone = nullptr; if (called_computations_.empty()) { @@ -824,10 +912,8 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // parameter instruction. int64 param_no = fused_parameters.size(); // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. Strip the leading "%" from the operand name - // to avoid a double %%. - string param_name = - StrCat(operand->name().substr(1), ".param_", param_no); + // (non-fusion) computation. + string param_name = StrCat(operand->name(), ".param_", param_no); fused_param = fused_instructions_computation()->AddParameter( CreateParameter(param_no, operand->shape(), param_name)); AppendOperand(operand); @@ -908,7 +994,10 @@ RandomDistribution HloInstruction::random_distribution() const { bool HloInstruction::HasSideEffect() const { switch (opcode_) { case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kRng: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: @@ -961,11 +1050,12 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, - tensorflow::gtl::ArraySlice new_operands) const { + tensorflow::gtl::ArraySlice new_operands, + HloModule* module) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { - VLOG(3) << " " << new_operand->name(); + VLOG(3) << " %" << new_operand->name(); } std::unique_ptr clone; @@ -1009,7 +1099,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kLe: case HloOpcode::kLt: case HloOpcode::kNe: - case HloOpcode::kDot: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1047,6 +1136,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateConvert(shape, new_operands[0]); break; + case HloOpcode::kBitcastConvert: + CHECK_EQ(new_operands.size(), 1); + clone = CreateBitcastConvert(shape, new_operands[0]); + break; case HloOpcode::kReducePrecision: CHECK_EQ(new_operands.size(), 1); clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, @@ -1057,9 +1150,13 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, *convolution_dimension_numbers_); break; + case HloOpcode::kDot: + CHECK_EQ(new_operands.size(), 2); + clone = CreateDot(shape, new_operands[0], new_operands[1], + *dot_dimension_numbers_); + break; case HloOpcode::kCrossReplicaSum: - CHECK_EQ(new_operands.size(), 1); - clone = CreateCrossReplicaSum(shape, new_operands[0]); + clone = CreateCrossReplicaSum(shape, new_operands); break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); @@ -1131,7 +1228,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateConstant(literal_->CloneToUnique()); break; case HloOpcode::kFusion: - clone = CloneFusionWithNewOperands(shape, new_operands); + clone = CloneFusionWithNewOperands(shape, new_operands, module); break; case HloOpcode::kParameter: clone = CreateParameter(parameter_number_, shape, parameter_name_); @@ -1162,21 +1259,28 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); break; + case HloOpcode::kConditional: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } clone->set_metadata(metadata_); + if (has_sharding()) { + clone->set_sharding(sharding()); + } + clone->set_parent(parent_); return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone( - const string& suffix) const { +std::unique_ptr HloInstruction::Clone(const string& suffix, + HloModule* module) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_); + CloneWithNewOperands(shape_, operands_, module); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1210,16 +1314,12 @@ std::unique_ptr HloInstruction::Clone( } } } - clone->set_parent(parent_); - if (has_sharding()) { - clone->set_sharding(sharding()); - } return clone; } std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice operands) const { + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloModule* module) const { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(parent() != nullptr); @@ -1230,13 +1330,14 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( new_instruction->AppendOperand(new_operand); } // Clone all the fused instructions for the new fusion instruction. - std::map old_to_new; + HloInstructionMap old_to_new; std::list> new_fused_instructions; // Create the list of fused parameters by mapping through the cloned, // fused instructions. for (HloInstruction* old_fused_parameter : fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back(old_fused_parameter->Clone()); + new_fused_instructions.push_back( + old_fused_parameter->Clone("clone", module)); HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); } @@ -1255,7 +1356,7 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( } new_fused_instructions.push_back( old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands)); + old_fused_instruction->shape(), new_operands, module)); HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); new_fused_instruction->set_parent(parent_); InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); @@ -1271,12 +1372,13 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( ++new_fused_instruction_iter) { computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); } + if (module == nullptr) { + module = GetModule(); + } auto fused_root_ = fused_expression_root(); new_instruction->called_computations_.push_back( - CHECK_NOTNULL(GetModule()) - ->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - new_instruction->set_parent(parent_); + CHECK_NOTNULL(module)->AddEmbeddedComputation( + computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); return new_instruction; } @@ -1350,7 +1452,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const { return i; } } - LOG(FATAL) << "target was not an operand"; + LOG(FATAL) << "target was not an operand: " << target->ToString(); } Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) { @@ -1423,7 +1525,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: @@ -1482,6 +1583,7 @@ bool HloInstruction::IdenticalSlowPath( // A convert result is determined by the primitive type that the operand is // converted into. case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: return shape().element_type() == other.shape().element_type(); // A reduce-precision operation is determined by the bit sizes. @@ -1495,6 +1597,10 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals( convolution_dimension_numbers(), other.convolution_dimension_numbers()); + // Check dot dimension numbers. + case HloOpcode::kDot: + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + other.dot_dimension_numbers()); // Reduction results are determined by the reduction dimension and the // reduction computation. @@ -1535,7 +1641,8 @@ bool HloInstruction::IdenticalSlowPath( other.padding_config()); case HloOpcode::kSlice: return slice_starts_ == other.slice_starts_ && - slice_limits_ == other.slice_limits_; + slice_limits_ == other.slice_limits_ && + slice_strides_ == other.slice_strides_; case HloOpcode::kDynamicSlice: return ShapeUtil::Compatible(shape(), other.shape()) && dynamic_slice_sizes_ == other.dynamic_slice_sizes_; @@ -1550,11 +1657,14 @@ bool HloInstruction::IdenticalSlowPath( return dimensions() == other.dimensions(); // These opcodes are not yet supported. + case HloOpcode::kConditional: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kSend: case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: return false; } } @@ -1757,6 +1867,32 @@ void HloInstruction::set_scatter(HloComputation* computation) { called_computations_[kScatterComputationIndex] = computation; } +HloComputation* HloInstruction::true_computation() const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + return called_computations_[kTrueComputationIndex]; +} + +HloComputation* HloInstruction::false_computation() const { + CHECK_EQ(HloOpcode::kConditional, opcode_); + return called_computations_[kFalseComputationIndex]; +} + +void HloInstruction::set_true_computation(HloComputation* true_computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + CHECK_EQ(HloOpcode::kConditional, opcode_); + called_computations_[kTrueComputationIndex] = true_computation; +} + +void HloInstruction::set_false_computation(HloComputation* false_computation) { + // Don't allow changing the computation for fused instructions so we don't + // have to recompute called_instructions for the entire fusion instruction. + CHECK(!IsFused()); + CHECK_EQ(HloOpcode::kConditional, opcode_); + called_computations_[kFalseComputationIndex] = false_computation; +} + string HloInstruction::SignatureString() const { string operands = Join(operands_, ", ", [](string* out, HloInstruction* operand) { @@ -1765,36 +1901,31 @@ string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -string HloInstruction::ExtendedOpcodeStr() const { - string opc_name = HloOpcodeString(opcode()); - HloOpcode opc = opcode(); - if (HloOpcode::kFusion == opc) { - opc_name += ":" + xla::ToString(fusion_kind()); - } - return opc_name; -} - -string HloInstruction::ToString(bool compact_operands, - bool include_metadata) const { +string HloInstruction::ToString(bool compact_operands, bool include_metadata, + bool include_large_constants) const { string result = - StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - ExtendedOpcodeStr(), "(", OperandsToString(compact_operands), ")"); + StrCat("%", name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", + OperandsToString(compact_operands, include_large_constants), ")"); for (const string& extra : ExtraAttributesToString()) { StrAppend(&result, ", ", extra); } if (include_metadata && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { - StrAppend(&result, " # metadata=", metadata_.ShortDebugString()); + StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } return result; } -string HloInstruction::OperandsToString(bool compact) const { +string HloInstruction::OperandsToString(bool compact, + bool include_large_constants) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. - if (!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) { + if ((!ShapeUtil::IsTuple(shape()) && + ShapeUtil::ElementsIn(shape()) <= 10) || + include_large_constants) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); @@ -1825,7 +1956,7 @@ string HloInstruction::OperandsToString(bool compact) const { operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { *out += ShapeUtil::HumanStringWithLayout(operand->shape()); if (!compact) { - StrAppend(out, " ", operand->name()); + StrAppend(out, " %", operand->name()); } }); const int64 remaining = operands_.size() - slice.size(); @@ -1838,16 +1969,20 @@ string HloInstruction::OperandsToString(bool compact) const { std::vector HloInstruction::ExtraAttributesToString() const { std::vector extra; + if (opcode() == HloOpcode::kFusion) { + extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); + } if (CanHaveDimensionsField()) { extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); } - if (window_ != nullptr) { - extra.push_back(window_util::ToString(*window_)); + if (window_ != nullptr && window_->dimensions_size() != 0) { + extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); } if (padding_config_ != nullptr) { - extra.push_back(StrCat("padding=", padding_config_->ShortDebugString())); + extra.push_back( + StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); } - if (!slice_starts_.empty() && !slice_limits_.empty()) { + if (opcode() == HloOpcode::kSlice) { std::vector bounds; bounds.reserve(slice_starts_.size()); const bool omit_stride = @@ -1860,10 +1995,23 @@ std::vector HloInstruction::ExtraAttributesToString() const { } extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); } + if (opcode() == HloOpcode::kDynamicSlice) { + extra.push_back( + StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); + } + if (opcode() == HloOpcode::kBatchNormTraining || + opcode() == HloOpcode::kBatchNormInference || + opcode() == HloOpcode::kBatchNormGrad) { + extra.push_back(StrCat("epsilon=", epsilon())); + extra.push_back(StrCat("feature_index=", feature_index())); + } if (convolution_dimension_numbers_ != nullptr) { extra.push_back(ConvolutionDimensionNumbersToString()); } + if (dot_dimension_numbers_ != nullptr) { + extra.push_back(DotDimensionNumbersToString()); + } if (opcode() == HloOpcode::kWhile) { extra.push_back(StrCat("condition=%", while_condition()->name())); @@ -1883,7 +2031,8 @@ std::vector HloInstruction::ExtraAttributesToString() const { }))); } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) { + if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || + opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { extra.push_back(StrCat("channel_id=", channel_id_)); } @@ -1893,21 +2042,37 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (has_sharding()) { extra.push_back(StrCat("sharding=", sharding().ToString())); } - if (!control_successors_.empty()) { - extra.push_back(StrCat( - "control-successors=", - Join(control_successors_, ", ", [](string* out, HloInstruction* succ) { - StrAppend(out, succ->name()); - }))); + if (!control_predecessors_.empty()) { + extra.push_back(StrCat("control-predecessors={", + Join(control_predecessors_, ", ", + [](string* out, HloInstruction* pre) { + StrAppend(out, "%", pre->name()); + }), + "}")); + } + if (opcode() == HloOpcode::kInfeed && !infeed_config_.empty()) { + extra.push_back(StrCat("infeed_config=\"", CEscape(infeed_config_), "\"")); + } + if (opcode() == HloOpcode::kOutfeed && !outfeed_config_.empty()) { + extra.push_back( + StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); + } + if (opcode() == HloOpcode::kRng) { + extra.push_back( + StrCat("distribution=", RandomDistributionToString(distribution_))); + } + if (opcode() == HloOpcode::kReducePrecision) { + extra.push_back(StrCat("exponent_bits=", exponent_bits_)); + extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); } return extra; } string HloInstruction::ToShortString() const { - return StrCat(name(), " = ", HloOpcodeString(opcode()), "(", + return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", Join(operands_, ", ", [](string* out, HloInstruction* operand) { - StrAppend(out, operand->name()); + StrAppend(out, "%", operand->name()); }), ")"); } @@ -1951,6 +2116,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = *convolution_dimension_numbers_; } + if (dot_dimension_numbers_ != nullptr) { + *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2001,8 +2169,10 @@ string HloInstruction::ToCategory() const { bool saw_rank_1 = false; bool saw_higher_rank = false; for (const auto* operand : operands()) { - saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; - saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; + if (!ShapeUtil::IsTuple(operand->shape())) { + saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; + saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; + } } if (saw_rank_1 && saw_higher_rank) { return "rank-1-broadcast binary fusion"; @@ -2055,23 +2225,13 @@ bool HloInstruction::IsFusable() const { if (tracing()) { return false; } - // Some kinds of instructions don't make sense to fuse. switch (opcode_) { - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: case HloOpcode::kParameter: - case HloOpcode::kTrace: - case HloOpcode::kSend: - case HloOpcode::kRecv: return false; - // Only fuse Rng if it is used once, otherwise the random numbers generated - // will be different in each fusion. If it is the root (user count = 0) - // then it is the equivalent of having one user. - case HloOpcode::kRng: - return users_.size() <= 1; + // Side effecting instrutions cannot be fused. default: - return true; + return !HasSideEffect(); } } @@ -2122,11 +2282,12 @@ HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), shape_(shape), - name_("%" + HloOpcodeString(opcode)) { + name_(HloOpcodeString(opcode)) { TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); } -Status HloInstruction::Visit(DfsHloVisitor* visitor) { +template +Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { switch (opcode_) { case HloOpcode::kAbs: return visitor->HandleAbs(this); @@ -2181,6 +2342,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleConcatenate(this); case HloOpcode::kConvert: return visitor->HandleConvert(this); + case HloOpcode::kBitcastConvert: + return visitor->HandleBitcastConvert(this); case HloOpcode::kCopy: return visitor->HandleCopy(this); case HloOpcode::kMultiply: @@ -2267,12 +2430,18 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleFusion(this); case HloOpcode::kCall: return visitor->HandleCall(this); + case HloOpcode::kConditional: + return visitor->HandleConditional(this); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this); - case HloOpcode::kSend: - return visitor->HandleSend(this); case HloOpcode::kRecv: return visitor->HandleRecv(this); + case HloOpcode::kRecvDone: + return visitor->HandleRecvDone(this); + case HloOpcode::kSend: + return visitor->HandleSend(this); + case HloOpcode::kSendDone: + return visitor->HandleSendDone(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2282,25 +2451,30 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { HloOpcodeString(opcode_).c_str()); } +// Explicit instantiations. +template Status HloInstruction::Visit(DfsHloVisitor* visitor); +template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor); + using DFSStack = tensorflow::gtl::InlinedVector, 16>; // Push "child" onto the dfs_stack if not already visited. Returns false if a // cycle was detected, and true otherwise. -inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack, +template +inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack, HloInstruction* child) { CHECK(child != nullptr); const int id = child->unique_id(); CHECK_GE(id, 0) << "instruction may not have a parent computation"; switch (visitor->GetVisitState(id)) { - case DfsHloVisitor::kVisiting: + case Visitor::kVisiting: return false; - case DfsHloVisitor::kVisited: + case Visitor::kVisited: // Nothing to do return true; - case DfsHloVisitor::kNotVisited: + case Visitor::kNotVisited: dfs_stack->push_back(std::make_pair(id, child)); return true; } @@ -2309,7 +2483,8 @@ inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack, using InternalCompareFunction = std::function, std::pair)>; -static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, +template +static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, const InternalCompareFunction* operand_order, bool ignore_control_predecessors) { visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds()); @@ -2330,26 +2505,27 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, HloInstruction* current_node = dfs_stack.back().second; CHECK_GE(current_id, 0) << current_id << ": " << current_node << ": instruction may not have parent computation"; - DfsHloVisitor::VisitState visit_state = visitor->GetVisitState(current_id); - if (visit_state == DfsHloVisitor::kVisited) { + typename Visitor::VisitState visit_state = + visitor->GetVisitState(current_id); + if (visit_state == Visitor::kVisited) { dfs_stack.pop_back(); - VLOG(3) << "Not visiting HLO " << current_node->name() + VLOG(3) << "Not visiting HLO %" << current_node->name() << " as it was already visited."; continue; } - if (visit_state == DfsHloVisitor::kVisiting) { + if (visit_state == Visitor::kVisiting) { dfs_stack.pop_back(); TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); - VLOG(2) << "Visiting HLO " << current_node->name(); + VLOG(2) << "Visiting HLO %" << current_node->name(); TF_RETURN_IF_ERROR(current_node->Visit(visitor)); - visitor->SetVisitState(current_id, DfsHloVisitor::kVisited); + visitor->SetVisitState(current_id, Visitor::kVisited); TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); continue; } - visitor->SetVisitState(current_id, DfsHloVisitor::kVisiting); + visitor->SetVisitState(current_id, Visitor::kVisiting); const size_t old_dfs_stack_size = dfs_stack.size(); for (HloInstruction* child : current_node->operands()) { @@ -2383,9 +2559,11 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, return Status::OK(); } -Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, +template +Status HloInstruction::Accept(DfsHloVisitorBase* visitor, + bool call_finish_visit, bool ignore_control_predecessors) { - VLOG(3) << "HloInstruction::Accept(" << name() << ")"; + VLOG(3) << "HloInstruction::Accept(%" << name() << ")"; TF_RETURN_IF_ERROR( PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { @@ -2394,10 +2572,14 @@ Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, return Status::OK(); } +// Explicit instantiations. +template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool); +template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool); + Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { - VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; + VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")"; InternalCompareFunction func = [&operand_order]( std::pair a, std::pair b) { @@ -2447,14 +2629,20 @@ bool OrderIsTopologicalSort(const std::vector& order) { } // namespace Status HloInstruction::Accept( - const FunctionVisitor::VisitorFunction& visitor_func) { + const std::function& visitor_func) { FunctionVisitor visitor(visitor_func); return this->Accept(&visitor); } +Status HloInstruction::Accept( + const std::function& visitor_func) const { + ConstFunctionVisitor visitor(visitor_func); + return this->Accept(&visitor); +} + Status HloInstruction::AcceptOrdered( DfsHloVisitor* visitor, const std::vector& order) { - VLOG(2) << "HloInstruction::AcceptOrdered(" << name() << ")"; + VLOG(2) << "HloInstruction::AcceptOrdered(%" << name() << ")"; TF_RET_CHECK(OrderIsTopologicalSort(order)); // Compute the predecessors of this instruction. @@ -2473,7 +2661,7 @@ Status HloInstruction::AcceptOrdered( // The visitor can mark instructions as visited to skip particular // instructions. if (visitor->DidVisit(*const_instruction)) { - VLOG(3) << "Not visiting HLO " << const_instruction->name() + VLOG(3) << "Not visiting HLO %" << const_instruction->name() << " as it was already visited."; continue; } @@ -2482,7 +2670,7 @@ Status HloInstruction::AcceptOrdered( const_cast(const_instruction); TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(2) << "Visiting HLO " << instruction->name(); + VLOG(2) << "Visiting HLO %" << instruction->name(); TF_RETURN_IF_ERROR(instruction->Visit(visitor)); visitor->SetVisited(*instruction); TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); @@ -2514,33 +2702,7 @@ std::vector HloInstruction::OperandIndices( } bool HloInstruction::IsElementwiseBinary() const { - switch (opcode_) { - // Binary elementwise operations. If you update this, please update - // IsElementwise() accordingly. - case HloOpcode::kAdd: - case HloOpcode::kComplex: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kAnd: - case HloOpcode::kOr: - case HloOpcode::kShiftLeft: - case HloOpcode::kShiftRightArithmetic: - case HloOpcode::kShiftRightLogical: - return true; - default: - return false; - } + return IsElementwise() && operand_count() == 2; } bool HloInstruction::IsElementwise() const { @@ -2551,10 +2713,10 @@ bool HloInstruction::IsElementwise() const { // Unary elementwise operations. case HloOpcode::kAbs: - case HloOpcode::kAtan2: case HloOpcode::kRoundNearestAfz: case HloOpcode::kCeil: case HloOpcode::kConvert: + case HloOpcode::kBitcastConvert: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: @@ -2569,11 +2731,12 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: + CHECK_EQ(1, operand_count()); return true; // Binary elementwise operations, the same as in IsElementwiseBinary(). - // If you update this, please update IsElementwiseBinary() accordingly. case HloOpcode::kAdd: + case HloOpcode::kAtan2: case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kEq: @@ -2593,6 +2756,7 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + CHECK_EQ(2, operand_count()); return true; // Ternary elementwise operations. @@ -2837,6 +3001,61 @@ StatusOr StringToFusionKind( return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); } +string PaddingConfigToString(const PaddingConfig& padding) { + bool has_interior_padding = + std::any_of(padding.dimensions().begin(), padding.dimensions().end(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.interior_padding() != 0; + }); + return Join( + padding.dimensions(), "x", + [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { + StrAppend( + out, dim.edge_padding_low(), "_", dim.edge_padding_high(), + has_interior_padding ? StrCat("_", dim.interior_padding()) : ""); + }); +} + +string OpMetadataToString(const OpMetadata& metadata) { + std::vector result; + if (!metadata.op_type().empty()) { + result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\"")); + } + if (!metadata.op_name().empty()) { + result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\"")); + } + if (!metadata.source_file().empty()) { + result.push_back( + StrCat("source_file=\"", CEscape(metadata.source_file()), "\"")); + } + if (metadata.source_line() != 0) { + result.push_back(StrCat("source_line=", metadata.source_line())); + } + return Join(result, " "); +} + +string RandomDistributionToString(const RandomDistribution& distribution) { + return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); +} + +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -2852,36 +3071,30 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { const auto append_dims = [&](const std::vector& dims, const Shape& shape) { CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - for (int64 logical = 0; logical < dims.size(); ++logical) { - int64 physical = logical; - if (!shape.layout().minor_to_major().empty()) { - physical = LayoutUtil::Major(shape.layout(), logical); - } - result += dims[physical]; - } + StrAppend(&result, Join(dims, "")); }; // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". - std::vector lhs_dims(2 + dnums.spatial_dimensions().size()); + std::vector lhs_dims(2 + dnums.input_spatial_dimensions().size()); lhs_dims[dnums.input_batch_dimension()] = 'b'; lhs_dims[dnums.input_feature_dimension()] = 'f'; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i); + for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) { + lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i); } std::vector rhs_dims(2 + dnums.kernel_spatial_dimensions().size()); rhs_dims[dnums.kernel_input_feature_dimension()] = "i"; rhs_dims[dnums.kernel_output_feature_dimension()] = "o"; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { + for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) { rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i); } - std::vector output_dims(2 + dnums.spatial_dimensions().size()); + std::vector output_dims(2 + dnums.output_spatial_dimensions().size()); output_dims[dnums.output_batch_dimension()] = 'b'; output_dims[dnums.output_feature_dimension()] = 'f'; - for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { - output_dims[dnums.spatial_dimensions(i)] = StrCat(i); + for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) { + output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i); } result += "dim_labels="; @@ -2893,6 +3106,30 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { return result; } +string HloInstruction::DotDimensionNumbersToString() const { + string result; + if (dot_dimension_numbers_ == nullptr) { + return result; + } + const DotDimensionNumbers& dnums = *dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result += "lhs_batch_dims="; + StrAppend(&result, Join(dnums.lhs_batch_dimensions(), ",")); + } + result += "lhs_contracting_dims="; + StrAppend(&result, Join(dnums.lhs_contracting_dimensions(), ",")); + + result += ","; + if (!dnums.rhs_batch_dimensions().empty()) { + result += "rhs_batch_dims="; + StrAppend(&result, Join(dnums.rhs_batch_dimensions(), ",")); + } + result += "rhs_contracting_dims="; + StrAppend(&result, Join(dnums.rhs_contracting_dimensions(), ",")); + + return result; +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e714d7bc71d86815b1b2df44cdd5c67281cdeb62..03cf9aaf907e7437596b9cc1f093fd79d22963b9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" @@ -83,12 +84,16 @@ class HloInstruction { // must contain all operands of the newly constructed instruction. // computation_map: a map from computation name to HloComputation*. This map // must contain all computations which the newly constructed instruction - // calls. If the instruction is a fusion instruction, then the fusion - // computation is added to this map and the module. + // calls. + // add_fused_computation: A function to call to add a fused + // computation. Used (clearly) when the instruction is a fusion + // instruction. static StatusOr> CreateFromProto( HloModule* module, const HloInstructionProto& proto, const tensorflow::gtl::FlatMap& instruction_map, - tensorflow::gtl::FlatMap* computation_map); + const tensorflow::gtl::FlatMap& computation_map, + const std::function)>& + add_fused_computation); // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, @@ -155,6 +160,12 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + static std::unique_ptr CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers); + // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to // reduce it to. @@ -164,13 +175,19 @@ class HloInstruction { // Creates a cross replica sum op. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, HloInstruction* operand); + const Shape& shape, + tensorflow::gtl::ArraySlice operands); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. static std::unique_ptr CreateConvert(const Shape& shape, HloInstruction* operand); + // Creates a bitcast conversion instruction, where operand is the data to + // convert and shape is the target shape for the conversion. + static std::unique_ptr CreateBitcastConvert( + const Shape& shape, HloInstruction* operand); + // Creates an infeed instruction, which reads data of the given shape from the // Infeed interface of the device. static std::unique_ptr CreateInfeed(const Shape& shape, @@ -181,18 +198,28 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::StringPiece outfeed_config); - // Creates a send instruction with the given channel id, which sends the - // operand data to a unique receive instruction in another computation that - // has the same channel id. + // Creates an asynchronous send instruction with the given channel id, which + // initiates sending the operand data to a unique receive instruction in + // another computation that has the same channel id. static std::unique_ptr CreateSend(HloInstruction* operand, int64 channel_id); - // Creates a receive instruction with the given channel id, which receives - // data of the given shape from a unique send instruction in another - // computation that has the same channel id. + // Blocks until data transfer for the Send instruction (operand) is complete. + // The operand must be kSend. + static std::unique_ptr CreateSendDone( + HloInstruction* operand); + + // Creates an asynchronous receive instruction with the given channel id, + // which allocates resources to receive data of the given shape from a unique + // send instruction in another computation that has the same channel id. static std::unique_ptr CreateRecv(const Shape& shape, int64 channel_id); + // Blocks until data transfer for the Recv instruction (operand) is complete + // and returns the receive buffer. The operand must be kRecv. + static std::unique_ptr CreateRecvDone( + HloInstruction* operand); + // Creates a slice instruction, where the operand is sliced by the given // start/limit indices. static std::unique_ptr CreateSlice( @@ -202,7 +229,7 @@ class HloInstruction { tensorflow::gtl::ArraySlice strides); // Creates a slice instruction, where the first operand is sliced by - // start indices specified in the second operand, and by size specfied in + // start indices specified in the second operand, and by size specified in // 'slice_sizes'. static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, @@ -295,6 +322,11 @@ class HloInstruction { HloComputation* body, HloInstruction* init); + static std::unique_ptr CreateConditional( + const Shape& shape, HloInstruction* pred, + HloInstruction* true_computation_arg, HloComputation* true_computation, + HloInstruction* false_computation_arg, HloComputation* false_computation); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -302,6 +334,11 @@ class HloInstruction { static std::unique_ptr CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root); + static std::unique_ptr CreateFusion( + const Shape& shape, FusionKind fusion_kind, + tensorflow::gtl::ArraySlice operands, + HloComputation* fusion_computation); + // Creates a fusion instruction that represents backward convolution. This is // similar to CreateFusion, but with extra arguments indicating the window and // dimemsion mapping of the backward convolution. @@ -391,7 +428,7 @@ class HloInstruction { Status RemoveControlDependencyTo(HloInstruction* instruction); // Returns the set of control predecessors (successors) of this - // instruction. Control predecessors (sucessors) must execute before (after) + // instruction. Control predecessors (successors) must execute before (after) // the current instruction. const std::vector& control_predecessors() const { return control_predecessors_; @@ -458,8 +495,15 @@ class HloInstruction { // reachable via control dependencies will not be visited, and the postorder // will not take control dependencies into account. It is as if the control // dependencies didn't exist in the graph at all. - Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true, + template + Status Accept(DfsHloVisitorBase* visitor, + bool call_finish_visit = true, bool ignore_control_predecessors = false); + Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true, + bool ignore_control_predecessors = false) const { + return const_cast(this)->Accept( + visitor, call_finish_visit, ignore_control_predecessors); + } // Same as Accept() above, but the order of operand and control predecessor // visitation is determined by the given operand order; if compare(A, B) == @@ -472,7 +516,9 @@ class HloInstruction { // Performs a postorder DFS visit using this node as the root. Calls the given // visitor function at each instruction. - Status Accept(const FunctionVisitor::VisitorFunction& visitor_func); + Status Accept(const std::function& visitor_func); + Status Accept( + const std::function& visitor_func) const; // Visits all instructions rooted at this instruction using the given visitor // in the given order. 'order' must contain at least the set of instructions @@ -485,7 +531,8 @@ class HloInstruction { const std::vector& order); // Visit this instruction and only this instruction with the given visitor. - Status Visit(DfsHloVisitor* visitor); + template + Status Visit(DfsHloVisitorBase* visitor); // Returns the literal associated with this instruction. // @@ -583,18 +630,27 @@ class HloInstruction { void set_select(HloComputation* select); void set_scatter(HloComputation* scatter); + // Gets/sets the true and false HloComputation for Conditional. The setters + // should only be called by HloModule or HloComputation methods. + // + // Precondition: The instruction is a Conditional instruction. + HloComputation* true_computation() const; + HloComputation* false_computation() const; + void set_true_computation(HloComputation* true_computation); + void set_false_computation(HloComputation* false_computation); + // Returns a string for the signature of this instruction if considered as a // function, e.g. the signature of an F32 add is (F32, F32) -> F32. string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false, - bool include_metadata = true) const; + string ToString(bool compact_operands = false, bool include_metadata = true, + bool include_large_constants = false) const; // Components of the ToString() representation: // Returns a string representation of the operand list. - string OperandsToString(bool compact) const; + string OperandsToString(bool compact, bool include_large_constants) const; // Returns string representation of op-specific attributes. std::vector ExtraAttributesToString() const; @@ -843,6 +899,11 @@ class HloInstruction { return *window_; } + // Sets the window data in a windowed operation such as convolution. + void set_window(const Window& window) { + window_ = MakeUnique(window); + } + // Returns the padding configuration for a pad node. // // Precondition: opcode() == HloOpcode::kPad @@ -861,6 +922,15 @@ class HloInstruction { // Returns the dump string of the convolution dimension numbers. string ConvolutionDimensionNumbersToString() const; + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + CHECK(dot_dimension_numbers_ != nullptr); + return *dot_dimension_numbers_; + } + + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -870,12 +940,19 @@ class HloInstruction { // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of // the instruction to form the name of the cloned instruction. - std::unique_ptr Clone(const string& suffix = "clone") const; + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). + std::unique_ptr Clone(const string& suffix = "clone", + HloModule* module = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). std::unique_ptr CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice operands) const; + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloModule* module = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -945,11 +1022,6 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns the opcode string for this instruction. This is the result from - // HloOpcodeString plus, for fusion nodes, the fusion kind, separated by a - // ':'. - string ExtendedOpcodeStr() const; - // Returns a string identifier for this instruction. If no string identifier // has been explicitly set, then the identifier is the serialized pointer to // this instruction. @@ -1061,8 +1133,8 @@ class HloInstruction { // Clones a fusion instruction with a new shape and operands. std::unique_ptr CloneFusionWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice operands) const; + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloModule* module = nullptr) const; // Returns true if this instruction can legally have the dimensions field // set. Used for checking precondition of dimensions field accessors. @@ -1117,6 +1189,9 @@ class HloInstruction { // Describes the dimension numbers used for a convolution. std::unique_ptr convolution_dimension_numbers_; + // Describes the dimension numbers used for a dot. + std::unique_ptr dot_dimension_numbers_; + // Describes the [begin, end) index range for a slice. std::vector slice_starts_; std::vector slice_limits_; @@ -1160,6 +1235,10 @@ class HloInstruction { // kSelectAndScatter computations. kSelectComputationIndex = 0, kScatterComputationIndex = 1, + + // kConditional computations. + kTrueComputationIndex = 0, + kFalseComputationIndex = 1, }; // Outfeed configuration information, only present for kOutfeed. @@ -1207,8 +1286,37 @@ string ToString(HloInstruction::FusionKind kind); StatusOr StringToFusionKind( const string& kind_name); +// Custom (de)stringification functions for protos that live inside +// HloInstruction. +string PaddingConfigToString(const PaddingConfig& padding); +string OpMetadataToString(const OpMetadata& metadata); +string RandomDistributionToString(const RandomDistribution& distribution); +StatusOr StringToRandomDistribution(const string& name); + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); +// Map classes that guarantee a deterministic iteration order when the key is +// an HloInstruction* or a const HloInstruction*. +// To make the iteration order over the map deterministic, the comparator +// should not be using the pointer values, but rather an intrinsic property of +// the hlo. +// +// Note that this cannot be used for HLO instructions across multiple modules +// since the id of HLO instructions are only unique within each HLO module. +struct HloPtrComparator { + bool operator()(const HloInstruction* const& lhs, + const HloInstruction* const& rhs) const { + return lhs->unique_id() < rhs->unique_id(); + } +}; + +template +using HloInstructionMap = std::map; + +template +using ConstHloInstructionMap = + std::map; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 4ead64d997df1a6a85b028374949a4e5c9eab549..aa3fd0cf4f7410ed7034c65d72e16489d4f0ba71 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -792,8 +792,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // sub = Sub(mul, clamp) // tuple = Tuple({sub, sub, mul, C1}) // - // Notable complexities are repeated operands in a same instruction, different - // shapes, use of value in different expressions. + // Notable complexities are repeated operands in the same instruction, + // different shapes, use of value in different expressions. auto c1 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); auto c2 = builder.AddInstruction( @@ -1068,8 +1068,11 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1088,48 +1091,6 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { root2->operand(1)->operand(0)->shape())); } -TEST_F(HloInstructionTest, IsRandomFusable) { - auto shape = ShapeUtil::MakeShape(F32, {2, 2}); - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->opcode()); - } - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - builder.AddInstruction(HloInstruction::CreateUnary( - shape, HloOpcode::kNegate, rng)); - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode()); - } -} - - TEST_F(HloInstructionTest, CloneSuffixNames) { // Test that the suffix string added to cloned instructions is not // duplicated. Rather a numeric incrementing value should be appended. That @@ -1138,35 +1099,34 @@ TEST_F(HloInstructionTest, CloneSuffixNames) { // Test cloning the same instruction multiple times. auto foo = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo"); - EXPECT_EQ(foo->Clone()->name(), "%foo.clone"); - EXPECT_EQ(foo->Clone()->Clone()->name(), "%foo.clone2"); - EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "%foo.clone3"); + EXPECT_EQ(foo->Clone()->name(), "foo.clone"); + EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2"); + EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3"); // Test custom suffixes. - EXPECT_EQ(foo->Clone("bar")->name(), "%foo.bar"); - EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "%foo.bar2"); - EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), - "%foo.bar2.clone"); + EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2"); + EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone"); // Test instruction name with a dot. auto foo_baz = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.baz"); - EXPECT_EQ(foo_baz->Clone()->name(), "%foo.baz.clone"); + EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone"); // Test incrementing a large number after the suffix. auto foo_clone234 = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clone234"); - EXPECT_EQ(foo_clone234->Clone()->name(), "%foo.clone235"); + EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235"); // Test a non-numeric string after the cloning suffix. auto foo_clonexyz = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz"); - EXPECT_EQ(foo_clonexyz->Clone()->name(), "%foo.clonexyz.clone"); + EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone"); // Test a name with multiple appearances of the suffix. auto foo_clone_clone3 = HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3"); - EXPECT_EQ(foo_clone_clone3->Clone()->name(), "%foo.clone.clone4"); + EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4"); } TEST_F(HloInstructionTest, Stringification) { @@ -1183,21 +1143,25 @@ TEST_F(HloInstructionTest, Stringification) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); EXPECT_EQ(dot->ToString(false, false), "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " - "%transpose)"); + "%transpose), lhs_contracting_dims=1,rhs_contracting_dims=0"); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); HloInstruction* fusion = computation->CreateFusionInstruction( {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); - EXPECT_EQ(fusion->ToString(false, false), - "%fusion = f32[5,20]{1,0} fusion:kTransposeDot(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), calls=%fused_computation"); + EXPECT_EQ( + fusion->ToString(false, false), + "%fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " + "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 0660d5a1820f068a1e6a765c133f3b9654339c57..4255d6086625dfb9a045e4431e968a5ee0106ac7 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -73,6 +73,35 @@ void HloMatcher::DescribeTo(::std::ostream* os) const { } } +bool HloParameterMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->parameter_number() != parameter_number_) { + *listener << "has wrong parameter number (got " + << instruction->parameter_number() << ", want " + << parameter_number_ << ")"; + return false; + } + return true; +} + +bool HloGetTupleElementMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->tuple_index() != tuple_index_) { + *listener << "has wrong tuple index (got " << instruction->tuple_index() + << ", want " << tuple_index_ << ")"; + return false; + } + return true; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index bc5ed029a45b4f92a240138dc1e933610efe1789..992f55788b4900949f4994ba5b7be015bcd0d3de 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -38,6 +38,36 @@ class HloMatcher : public ::testing::MatcherInterface { std::vector<::testing::Matcher> operands_; }; +// Custom matcher for parameters, which accepts a parameter number. +class HloParameterMatcher : public HloMatcher { + public: + explicit HloParameterMatcher(int64 parameter_number) + : HloMatcher(HloOpcode::kParameter, /*operands=*/{}), + parameter_number_(parameter_number) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + int64 parameter_number_; +}; + +// Custom matcher for get-tuple-element instructions, which accepts a tuple +// index to match. +class HloGetTupleElementMatcher : public HloMatcher { + public: + explicit HloGetTupleElementMatcher( + ::testing::Matcher operand, int64 tuple_index) + : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}), + tuple_index_(tuple_index) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + int64 tuple_index_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -57,6 +87,7 @@ HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); HLO_MATCHER(Concatenate); +HLO_MATCHER(Conditional); HLO_MATCHER(Constant); HLO_MATCHER(Convert); HLO_MATCHER(Convolution); @@ -72,7 +103,6 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); -HLO_MATCHER(GetTupleElement); HLO_MATCHER(Gt); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); @@ -90,9 +120,9 @@ HLO_MATCHER(Ne); HLO_MATCHER(Negate); HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); -HLO_MATCHER(Parameter); HLO_MATCHER(Power); HLO_MATCHER(Recv); +HLO_MATCHER(RecvDone); HLO_MATCHER(Reduce); HLO_MATCHER(ReducePrecision); HLO_MATCHER(ReduceWindow); @@ -103,6 +133,7 @@ HLO_MATCHER(Rng); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); +HLO_MATCHER(SendDone); HLO_MATCHER(ShiftLeft); HLO_MATCHER(ShiftRightLogical); HLO_MATCHER(ShiftRightArithmetic); @@ -115,6 +146,43 @@ HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); HLO_MATCHER(While); + +// The special cases below let you check additional information about the +// HloInstruction, beyond just its opcode and operands. In all cases you can +// still use the generic matcher which doesn't check this info. +// +// Feel free to add additional custom matchers below. + +// - Parameter(N) matches parameter number N. +// - Parameter() matches any parameter. +inline ::testing::Matcher Parameter( + int64 parameter_number) { + return ::testing::MakeMatcher( + new ::xla::testing::HloParameterMatcher(parameter_number)); +} +inline ::testing::Matcher Parameter() { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kParameter, {})); +} + +// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th +// tuple element of operand, while GetTupleElement(operand) matches any GTE +// operation on operand, and GetTupleElement() matches any GTE operation at all. +inline ::testing::Matcher GetTupleElement( + ::testing::Matcher operand, int64 tuple_index) { + return ::testing::MakeMatcher( + new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index)); +} +inline ::testing::Matcher GetTupleElement( + ::testing::Matcher operand) { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand})); +} +inline ::testing::Matcher GetTupleElement() { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {})); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 1758f2760c46a5f0f5876ac6ba8dd013e71455b6..6fe2134466ffaf1402e5ecbc81aea9aafe2a468b 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -35,14 +35,15 @@ namespace xla { HloModule::HloModule(const string& name, const VersionedComputationHandle& entry_computation_handle, const HloModuleConfig& config) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), config_(config), has_entry_computation_handle_(true), entry_computation_handle_(entry_computation_handle) {} -HloModule::HloModule(const string& name) : name_(name) {} +HloModule::HloModule(const string& name) + : name_(NameUniquer::GetSanitizedName(name)) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) - : name_(name), config_(config) {} + : name_(NameUniquer::GetSanitizedName(name)), config_(config) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -170,20 +171,17 @@ void HloModule::ReplaceComputations( computations_ = std::move(new_computations); } -string HloModule::ToString() const { +string HloModule::ToString(bool include_large_constants) const { std::ostringstream s; s << "HloModule " << name() << ":\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { - // Fusion computations are emitted with their fusion instruction and - // therefore don't need to be emitted as a separate comptutation in the - // module. - if (computation->IsFusionComputation()) { - continue; - } if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString() << "\n\n"; + s << computation->ToString( + /*nested_level=*/0, + /*include_large_constants=*/include_large_constants) + << "\n\n"; } return s.str(); } @@ -293,9 +291,16 @@ StatusOr> HloModule::CreateFromProto( tensorflow::gtl::FlatMap computation_map; for (const HloComputationProto& computation_proto : proto.computations()) { - TF_ASSIGN_OR_RETURN(std::unique_ptr computation, - HloComputation::CreateFromProto( - module.get(), computation_proto, &computation_map)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, computation_map, + /*add_fused_computation=*/ + [&module](std::unique_ptr fused_computation) { + module->AddComputationInternal(std::move(fused_computation), + /*is_entry=*/false, + /*uniquify_names=*/false); + })); CHECK_NE(computation.get(), nullptr); TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); string computation_name = computation->name(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index ad11d56006a79b509309daba55e94342911f76a1..5141e7bc8d4cf0ef4cd83310772e0c5d66b5da12 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -85,7 +85,11 @@ class HloModule { std::unique_ptr Clone(const string& suffix = "clone") const; // Return a pointer to the entry computation of the module.. - HloComputation* entry_computation() const { + const HloComputation* entry_computation() const { + CHECK_NE(nullptr, entry_computation_); + return entry_computation_; + } + HloComputation* entry_computation() { CHECK_NE(nullptr, entry_computation_); return entry_computation_; } @@ -139,7 +143,7 @@ class HloModule { const HloModuleConfig& config() const { return config_; } - string ToString() const; + string ToString(bool include_large_constants = false) const; // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index 8974deb530c2e4561b5ab57f43c65fd525db3617..822e2f1f53e5ee460b88c2241ecf7f6b91ef608b 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -39,8 +39,8 @@ void HloModuleConfig::SetDefaultComputationLayout( } string HloModuleConfig::compilation_cache_key() const { - string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_, - "::hybrid=", has_hybrid_result_); + string key = + tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_); StrAppend(&key, "::("); std::vector params; for (const ShapeLayout& param_layout : diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 4a7ead9c104d2ed50d5c895b3cdf2d3767ae16e8..a5ee895e48448fbb8fa3879dc1b6764c1f9f6966 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -104,16 +104,6 @@ class HloModuleConfig { // Whether to enable HLO-level profiling. bool hlo_profiling_enabled_ = false; - // If this flag is true, the generated executable will return a ShapedBuffer - // holding the result of the computation. In a ShapedBuffer, tuples have their - // structure held in host memory and the element arrays (leaves of the tuple - // structure) stored in device memory. The ShapedBuffer is considered "hybrid" - // because its leaves are on device but its structure is stored on - // host. Otherwise, if this flag is false, the generated executable will - // return a DeviceMemoryBase where the result is held entirely in device - // memory. - bool has_hybrid_result_ = false; - // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 20eef2f7d53251a374971e55441f6a4585e9b35c..bf6440d66cac0d3a929c377202b212aba262f887 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -101,7 +101,7 @@ TEST_F(HloModuleTest, CloneTest) { for (auto origin = post_order.begin(), copied = post_order_copied.begin(); origin != post_order.end() && copied != post_order_copied.end(); ++origin, ++copied) { - EXPECT_EQ((*origin)->name() + "copy", (*copied)->name()); + EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name()); } } @@ -125,6 +125,26 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { EXPECT_EQ(post_order.front(), computation1); } +TEST_F(HloModuleTest, LargeConstantToString) { + // Create a module with a single computation. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder("Constant"); + std::vector values(16, 42.0); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1(values))); + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "ROOT %constant = f32[16]{0} constant({...})\n}\n\n", + module->ToString(/*include_large_constants=*/false)); + EXPECT_EQ( + "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, " + "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n", + module->ToString(/*include_large_constants=*/true)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 157d19f5a9996ff90c4a5c3655f82ff5b8e62cfc..d1eaf357855205f1e9867e86f3042b96b6beff97 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -21,243 +21,22 @@ limitations under the License. namespace xla { string HloOpcodeString(HloOpcode opcode) { - // Note: Do not use ':' in opcode strings. It is used as a special character - // in these places: - // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to - // separate the opcode from the fusion kind - // - In fully qualified names (HloInstruction::FullyQualifiedName()), to - // separate the qualifiers (name of the computation and potentially the - // fusion instruction) from the name switch (opcode) { - case HloOpcode::kAbs: - return "abs"; - case HloOpcode::kAdd: - return "add"; - case HloOpcode::kAnd: - return "and"; - case HloOpcode::kAtan2: - return "atan2"; - case HloOpcode::kBatchNormTraining: - return "batch-norm-training"; - case HloOpcode::kBatchNormInference: - return "batch-norm-inference"; - case HloOpcode::kBatchNormGrad: - return "batch-norm-grad"; - case HloOpcode::kBitcast: - return "bitcast"; - case HloOpcode::kBroadcast: - return "broadcast"; - case HloOpcode::kCall: - return "call"; - case HloOpcode::kClamp: - return "clamp"; - case HloOpcode::kComplex: - return "complex"; - case HloOpcode::kConcatenate: - return "concatenate"; - case HloOpcode::kConstant: - return "constant"; - case HloOpcode::kConvert: - return "convert"; - case HloOpcode::kConvolution: - return "convolution"; - case HloOpcode::kCos: - return "cosine"; - case HloOpcode::kCrossReplicaSum: - return "cross-replica-sum"; - case HloOpcode::kCustomCall: - return "custom-call"; - case HloOpcode::kCopy: - return "copy"; - case HloOpcode::kDivide: - return "divide"; - case HloOpcode::kDot: - return "dot"; - case HloOpcode::kDynamicSlice: - return "dynamic-slice"; - case HloOpcode::kDynamicUpdateSlice: - return "dynamic-update-slice"; - case HloOpcode::kEq: - return "equal-to"; - case HloOpcode::kExp: - return "exponential"; - case HloOpcode::kFloor: - return "floor"; - case HloOpcode::kCeil: - return "ceil"; - case HloOpcode::kFusion: - return "fusion"; - case HloOpcode::kGe: - return "greater-than-or-equal-to"; - case HloOpcode::kGetTupleElement: - return "get-tuple-element"; - case HloOpcode::kGt: - return "greater-than"; - case HloOpcode::kImag: - return "imag"; - case HloOpcode::kInfeed: - return "infeed"; - case HloOpcode::kIsFinite: - return "is-finite"; - case HloOpcode::kLe: - return "less-than-or-equal-to"; - case HloOpcode::kLog: - return "log"; - case HloOpcode::kLt: - return "less-than"; - case HloOpcode::kMap: - return "map"; - case HloOpcode::kMaximum: - return "maximum"; - case HloOpcode::kMinimum: - return "minimum"; - case HloOpcode::kMultiply: - return "multiply"; - case HloOpcode::kNe: - return "not-equal-to"; - case HloOpcode::kNegate: - return "negate"; - case HloOpcode::kNot: - return "not"; - case HloOpcode::kOr: - return "or"; - case HloOpcode::kOutfeed: - return "outfeed"; - case HloOpcode::kPad: - return "pad"; - case HloOpcode::kParameter: - return "parameter"; - case HloOpcode::kPower: - return "power"; - case HloOpcode::kReal: - return "real"; - case HloOpcode::kRecv: - return "recv"; - case HloOpcode::kReduce: - return "reduce"; - case HloOpcode::kReducePrecision: - return "reduce-precision"; - case HloOpcode::kReduceWindow: - return "reduce-window"; - case HloOpcode::kRemainder: - return "remainder"; - case HloOpcode::kReshape: - return "reshape"; - case HloOpcode::kReverse: - return "reverse"; - case HloOpcode::kRng: - return "rng"; - case HloOpcode::kRoundNearestAfz: - return "round-nearest-afz"; - case HloOpcode::kSelectAndScatter: - return "select-and-scatter"; - case HloOpcode::kSelect: - return "select"; - case HloOpcode::kSend: - return "send"; - case HloOpcode::kShiftLeft: - return "shift-left"; - case HloOpcode::kShiftRightArithmetic: - return "shift-right-arithmetic"; - case HloOpcode::kShiftRightLogical: - return "shift-right-logical"; - case HloOpcode::kSign: - return "sign"; - case HloOpcode::kSin: - return "sine"; - case HloOpcode::kSlice: - return "slice"; - case HloOpcode::kSort: - return "sort"; - case HloOpcode::kSubtract: - return "subtract"; - case HloOpcode::kTanh: - return "tanh"; - case HloOpcode::kTrace: - return "trace"; - case HloOpcode::kTranspose: - return "transpose"; - case HloOpcode::kTuple: - return "tuple"; - case HloOpcode::kWhile: - return "while"; +#define CASE_OPCODE_STRING(enum_name, opcode_name, ...) \ + case HloOpcode::enum_name: \ + return opcode_name; + HLO_OPCODE_LIST(CASE_OPCODE_STRING) +#undef CASE_OPCODE_STRING } } StatusOr StringToHloOpcode(const string& opcode_name) { - static auto* opcode_map = new tensorflow::gtl::FlatMap( - {{"abs", HloOpcode::kAbs}, - {"add", HloOpcode::kAdd}, - {"and", HloOpcode::kAnd}, - {"batch-norm-training", HloOpcode::kBatchNormTraining}, - {"batch-norm-inference", HloOpcode::kBatchNormInference}, - {"batch-norm-grad", HloOpcode::kBatchNormGrad}, - {"bitcast", HloOpcode::kBitcast}, - {"broadcast", HloOpcode::kBroadcast}, - {"call", HloOpcode::kCall}, - {"clamp", HloOpcode::kClamp}, - {"concatenate", HloOpcode::kConcatenate}, - {"constant", HloOpcode::kConstant}, - {"convert", HloOpcode::kConvert}, - {"convolution", HloOpcode::kConvolution}, - {"cosine", HloOpcode::kCos}, - {"cross-replica-sum", HloOpcode::kCrossReplicaSum}, - {"custom-call", HloOpcode::kCustomCall}, - {"copy", HloOpcode::kCopy}, - {"divide", HloOpcode::kDivide}, - {"dot", HloOpcode::kDot}, - {"dynamic-slice", HloOpcode::kDynamicSlice}, - {"dynamic-update-slice", HloOpcode::kDynamicUpdateSlice}, - {"equal-to", HloOpcode::kEq}, - {"exponential", HloOpcode::kExp}, - {"floor", HloOpcode::kFloor}, - {"ceil", HloOpcode::kCeil}, - {"fusion", HloOpcode::kFusion}, - {"greater-than-or-equal-to", HloOpcode::kGe}, - {"get-tuple-element", HloOpcode::kGetTupleElement}, - {"greater-than", HloOpcode::kGt}, - {"infeed", HloOpcode::kInfeed}, - {"is-finite", HloOpcode::kIsFinite}, - {"less-than-or-equal-to", HloOpcode::kLe}, - {"log", HloOpcode::kLog}, - {"less-than", HloOpcode::kLt}, - {"map", HloOpcode::kMap}, - {"maximum", HloOpcode::kMaximum}, - {"minimum", HloOpcode::kMinimum}, - {"multiply", HloOpcode::kMultiply}, - {"not", HloOpcode::kNot}, - {"not-equal-to", HloOpcode::kNe}, - {"negate", HloOpcode::kNegate}, - {"or", HloOpcode::kOr}, - {"outfeed", HloOpcode::kOutfeed}, - {"pad", HloOpcode::kPad}, - {"parameter", HloOpcode::kParameter}, - {"power", HloOpcode::kPower}, - {"recv", HloOpcode::kRecv}, - {"reduce", HloOpcode::kReduce}, - {"reduce-precision", HloOpcode::kReducePrecision}, - {"reduce-window", HloOpcode::kReduceWindow}, - {"remainder", HloOpcode::kRemainder}, - {"reshape", HloOpcode::kReshape}, - {"reverse", HloOpcode::kReverse}, - {"rng", HloOpcode::kRng}, - {"round-nearest-afz", HloOpcode::kRoundNearestAfz}, - {"select-and-scatter", HloOpcode::kSelectAndScatter}, - {"select", HloOpcode::kSelect}, - {"send", HloOpcode::kSend}, - {"shift-left", HloOpcode::kShiftLeft}, - {"shift-right-arithmetic", HloOpcode::kShiftRightArithmetic}, - {"shift-right-logical", HloOpcode::kShiftRightLogical}, - {"sign", HloOpcode::kSign}, - {"sine", HloOpcode::kSin}, - {"slice", HloOpcode::kSlice}, - {"sort", HloOpcode::kSort}, - {"subtract", HloOpcode::kSubtract}, - {"tanh", HloOpcode::kTanh}, - {"trace", HloOpcode::kTrace}, - {"transpose", HloOpcode::kTranspose}, - {"tuple", HloOpcode::kTuple}, - {"while", HloOpcode::kWhile}}); + static auto* opcode_map = new tensorflow::gtl::FlatMap({ +#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \ + {opcode_name, HloOpcode::enum_name}, + HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY) +#undef STRING_TO_OPCODE_ENTRY + }); auto it = opcode_map->find(opcode_name); if (it == opcode_map->end()) { return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); @@ -265,31 +44,36 @@ StatusOr StringToHloOpcode(const string& opcode_name) { return it->second; } +#define CHECK_DEFAULT(property_name, opcode_name) false +#define CHECK_PROPERTY(property_name, opcode_name, value) \ + (value & property_name) +#define RESOLVE(_1, _2, target, ...) target +#define HAS_PROPERTY(property, ...) \ + RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__) + bool HloOpcodeIsComparison(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kEq: - case HloOpcode::kNe: - return true; - default: - return false; +#define CASE_IS_COMPARISON(enum_name, ...) \ + case HloOpcode::enum_name: \ + return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__); + HLO_OPCODE_LIST(CASE_IS_COMPARISON) +#undef CASE_IS_COMPARISON } } bool HloOpcodeIsVariadic(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kCall: - case HloOpcode::kConcatenate: - case HloOpcode::kFusion: - case HloOpcode::kMap: - case HloOpcode::kTuple: - return true; - default: - return false; +#define CASE_IS_VARIADIC(enum_name, ...) \ + case HloOpcode::enum_name: \ + return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__); + HLO_OPCODE_LIST(CASE_IS_VARIADIC) +#undef CASE_IS_VARIADIC } } +#undef HAS_PROPERTY +#undef RESOLVE +#undef CHECK_DEFAULT +#undef CHECK_PROPERTY + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 07c2d26f00f2338d306b57933e5f0fb77b38b892..f3f79357582ac7661a532e94031acdbca0b86784 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -28,83 +28,116 @@ namespace xla { // present in the XLA service protobuf. // // See the XLA documentation for the semantics of each opcode. +// +// Each entry has the format: +// (enum_name, opcode_name) +// or +// (enum_name, opcode_name, p1 | p2 | ...) +// +// with p1, p2, ... are members of HloOpcodeProperty. They are combined +// using bitwise-or. +// +// Note: Do not use ':' in opcode names. It is used as a special character +// in these places: +// - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to +// separate the opcode from the fusion kind +// - In fully qualified names (HloInstruction::FullyQualifiedName()), to +// separate the qualifiers (name of the computation and potentially the +// fusion instruction) from the name +#define HLO_OPCODE_LIST(V) \ + V(kAbs, "abs") \ + V(kAdd, "add") \ + V(kAtan2, "atan2") \ + V(kBatchNormGrad, "batch-norm-grad") \ + V(kBatchNormInference, "batch-norm-inference") \ + V(kBatchNormTraining, "batch-norm-training") \ + V(kBitcast, "bitcast") \ + V(kBitcastConvert, "bitcast-convert") \ + V(kBroadcast, "broadcast") \ + V(kCall, "call", kHloOpcodeIsVariadic) \ + V(kCeil, "ceil") \ + V(kClamp, "clamp") \ + V(kComplex, "complex") \ + V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConditional, "conditional") \ + V(kConstant, "constant") \ + V(kConvert, "convert") \ + V(kConvolution, "convolution") \ + V(kCopy, "copy") \ + V(kCos, "cosine") \ + V(kCrossReplicaSum, "cross-replica-sum") \ + V(kCustomCall, "custom-call") \ + V(kDivide, "divide") \ + V(kDot, "dot") \ + V(kDynamicSlice, "dynamic-slice") \ + V(kDynamicUpdateSlice, "dynamic-update-slice") \ + V(kEq, "equal-to", kHloOpcodeIsComparison) \ + V(kExp, "exponential") \ + V(kFloor, "floor") \ + V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGetTupleElement, "get-tuple-element") \ + V(kGt, "greater-than", kHloOpcodeIsComparison) \ + V(kImag, "imag") \ + V(kInfeed, "infeed") \ + V(kIsFinite, "is-finite") \ + V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kLog, "log") \ + V(kAnd, "and") \ + V(kNot, "not") \ + V(kOr, "or") \ + V(kLt, "less-than", kHloOpcodeIsComparison) \ + V(kMap, "map", kHloOpcodeIsVariadic) \ + V(kMaximum, "maximum") \ + V(kMinimum, "minimum") \ + V(kMultiply, "multiply") \ + V(kNe, "not-equal-to", kHloOpcodeIsComparison) \ + V(kNegate, "negate") \ + V(kOutfeed, "outfeed") \ + V(kPad, "pad") \ + V(kParameter, "parameter") \ + V(kPower, "power") \ + V(kReal, "real") \ + V(kRecv, "recv") \ + V(kRecvDone, "recv-done") \ + V(kReduce, "reduce") \ + V(kReducePrecision, "reduce-precision") \ + V(kReduceWindow, "reduce-window") \ + V(kRemainder, "remainder") \ + V(kReshape, "reshape") \ + V(kReverse, "reverse") \ + V(kRng, "rng") \ + V(kRoundNearestAfz, "round-nearest-afz") \ + V(kSelect, "select") \ + V(kSelectAndScatter, "select-and-scatter") \ + V(kSend, "send") \ + V(kSendDone, "send-done") \ + V(kShiftLeft, "shift-left") \ + V(kShiftRightArithmetic, "shift-right-arithmetic") \ + V(kShiftRightLogical, "shift-right-logical") \ + V(kSign, "sign") \ + V(kSin, "sine") \ + V(kSlice, "slice") \ + V(kSort, "sort") \ + V(kSubtract, "subtract") \ + V(kTanh, "tanh") \ + V(kTrace, "trace") \ + V(kTranspose, "transpose") \ + V(kTuple, "tuple", kHloOpcodeIsVariadic) \ + V(kWhile, "while") + enum class HloOpcode { - kAbs, - kAdd, - kAtan2, - kBatchNormGrad, - kBatchNormInference, - kBatchNormTraining, - kBitcast, - kBroadcast, - kCall, - kCeil, - kClamp, - kComplex, - kConcatenate, - kConstant, - kConvert, - kConvolution, - kCopy, - kCos, - kCrossReplicaSum, - kCustomCall, - kDivide, - kDot, - kDynamicSlice, - kDynamicUpdateSlice, - kEq, - kExp, - kFloor, - kFusion, - kGe, - kGetTupleElement, - kGt, - kImag, - kInfeed, - kIsFinite, - kLe, - kLog, - kAnd, - kNot, - kOr, - kLt, - kMap, - kMaximum, - kMinimum, - kMultiply, - kNe, - kNegate, - kOutfeed, - kPad, - kParameter, - kPower, - kReal, - kRecv, - kReduce, - kReducePrecision, - kReduceWindow, - kRemainder, - kReshape, - kReverse, - kRng, - kRoundNearestAfz, - kSelect, - kSelectAndScatter, - kSend, - kShiftLeft, - kShiftRightArithmetic, - kShiftRightLogical, - kSign, - kSin, - kSlice, - kSort, - kSubtract, - kTanh, - kTrace, - kTranspose, - kTuple, - kWhile, +#define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name, + HLO_OPCODE_LIST(DECLARE_ENUM) +#undef DECLARE_ENUM +}; + +// List of properties associated with opcodes. +// Properties are defined as increasing powers of two, so that we can use +// bitwise-or to combine properties, and bitwise-and to test for them. +enum HloOpcodeProperty { + kHloOpcodeIsComparison = 1 << 0, + kHloOpcodeIsVariadic = 1 << 1, }; // Returns a string representation of the opcode. @@ -125,7 +158,9 @@ bool HloOpcodeIsVariadic(HloOpcode opcode); // Returns the number of HloOpcode values. inline const uint32_t HloOpcodeCount() { - return static_cast(HloOpcode::kWhile) + 1; +#define HLO_COUNT_ONE(...) +1 +#define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE) + return HLO_XLIST_LENGTH(HLO_OPCODE_LIST); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 892c89f9df209f2e39005a4901feae6699ce4d0b..cd2ce5c69f030c65b889d67e082a3677b8739ddb 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -26,5 +26,46 @@ TEST(HloOpcodeTest, StringifyMultiply) { ASSERT_EQ("multiply", HloOpcodeString(HloOpcode::kMultiply)); } +TEST(HloOpcodeTest, OpcodeProperties) { + // Test counting macro. +#define SOME_LIST(X) \ + X(One) \ + X(Two) \ + X(Three) + EXPECT_EQ(3, HLO_XLIST_LENGTH(SOME_LIST)); +#undef SOME_LIST + + for (int i = 0; i < HloOpcodeCount(); ++i) { + auto opcode = static_cast(i); + // Test round-trip conversion to and from string. + EXPECT_EQ(opcode, StringToHloOpcode(HloOpcodeString(opcode)).ValueOrDie()); + + // Test some properties. + switch (opcode) { + case HloOpcode::kEq: + case HloOpcode::kNe: + case HloOpcode::kGt: + case HloOpcode::kLt: + case HloOpcode::kGe: + case HloOpcode::kLe: + EXPECT_TRUE(HloOpcodeIsComparison(opcode)); + break; + default: + EXPECT_FALSE(HloOpcodeIsComparison(opcode)); + } + switch (opcode) { + case HloOpcode::kCall: + case HloOpcode::kConcatenate: + case HloOpcode::kFusion: + case HloOpcode::kMap: + case HloOpcode::kTuple: + EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); + break; + default: + EXPECT_FALSE(HloOpcodeIsVariadic(opcode)); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 37009369797693dcd06647fad845bb0c004cec67..6f6e679a21870e46da85963c3b2998465ac43420 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -173,6 +173,19 @@ bool HloOrdering::UseIsBeforeValueDefinition( return true; } } + + // The use at a call occurs before values that are defined in the called + // computation. + if (use.instruction->opcode() == HloOpcode::kCall) { + const HloInstruction* call = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + call->to_apply())) { + VLOG(4) << " use is call " << use.instruction->name() + << " and def is in called computation"; + return true; + } + } + VLOG(4) << " use is not before value"; return false; } @@ -187,23 +200,6 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } - // Live-out values from the module can never have ranges strictly before any - // other value. - if (a.live_out_of_module()) { - VLOG(4) << "a is live out of module"; - return false; - } - - // Live-out values of computations can never have ranges strictly before any - // other value in the computation (including values nested in - // subcomputations). - if (a.live_out_of_computation() && - call_graph_->InstructionIsNestedIn(b.defining_instruction(), - a.defining_instruction()->parent())) { - VLOG(4) << "a is live out of computation containing b"; - return false; - } - // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (!UseIsBeforeValueDefinition(use, b, dataflow)) { diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc new file mode 100644 index 0000000000000000000000000000000000000000..e944ad15139af0d2f98e8e68d3d48303f47ecf1c --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -0,0 +1,67 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_profile_printer.h" + +#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" + +namespace xla { +string HloProfilePrinter::ToString(const int64* counters, + double clock_rate_ghz) const { + string result; + + for (int computation_idx = 0; computation_idx < computation_infos_size_; + computation_idx++) { + const HloComputationInfo& computation = computation_infos_[computation_idx]; + const HloInstructionInfo* instructions_begin = computation.instructions; + const HloInstructionInfo* instructions_end = + computation.instructions + computation.instructions_size; + bool any_instruction_profiled = + std::any_of(instructions_begin, instructions_end, + [&](const HloInstructionInfo& instruction_info) { + return counters[instruction_info.profile_index] != 0; + }); + + if (!any_instruction_profiled) { + continue; + } + + // Once we start using this in AOT for real, we will probably need a more + // minimal version of HumanReadableProfileBuilder. + HumanReadableProfileBuilder builder( + computation.name, counters[computation.profile_index], clock_rate_ghz); + + for (const auto* instruction = instructions_begin; + instruction != instructions_end; instruction++) { + builder.AddOp( + /*op_name=*/instruction->long_name, + /*short_name=*/instruction->short_name, instruction->category, + counters[instruction->profile_index], instruction->flop_count, + instruction->transcendental_count, instruction->bytes_accessed, + instruction->optimal_seconds); + } + + result += builder.ToString(); + } + + return result; +} + +HloProfilePrinter::~HloProfilePrinter() { + if (deleter_) { + deleter_(computation_infos_, computation_infos_size_); + } +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h new file mode 100644 index 0000000000000000000000000000000000000000..2f056490ae027872570f7a0821ee63114f49fab8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +// Instances of this class can pretty-print profile counters gathered from +// running an XLA computation without having access to the backing module. +class HloProfilePrinter { + public: + // Holds meta information about an HloInstruction. + // + // The pointer-typed fields can be owning or non-owning -- this decision is + // manifested as the deleter_ function in the containing HloProfilePrinter. + struct HloInstructionInfo { + // Textual information for pretty printing. + const char* long_name; + const char* short_name; + const char* category; + + // Metrics computed by HloCostAnalysis. + float flop_count; + float transcendental_count; + float bytes_accessed; + float optimal_seconds; + + // The index into the profile counters array for the HloInstruction + // corresponding to this HloInstructionInfo. + int64 profile_index; + }; + + // Holds meta information about an HloComputation. + // + // The pointer-typed fields can be owning or non-owning -- this decision is + // manifested as the deleter_ function in the containing HloProfilePrinter. + struct HloComputationInfo { + const char* name; + + // The index into the profile counters array for the HloInstruction + // corresponding to this HloComputationInfo. + int64 profile_index; + + HloInstructionInfo* instructions; + int64 instructions_size; + }; + + HloProfilePrinter( + HloComputationInfo* computation_infos, int64 computation_infos_size, + int64 profile_counters_size, + std::function deleter = nullptr) + : computation_infos_(computation_infos), + computation_infos_size_(computation_infos_size), + profile_counters_size_(profile_counters_size), + deleter_(std::move(deleter)) {} + + HloProfilePrinter(HloProfilePrinter&& other) { + std::swap(other.computation_infos_, computation_infos_); + std::swap(other.computation_infos_size_, computation_infos_size_); + std::swap(other.deleter_, deleter_); + } + + HloProfilePrinter(const HloProfilePrinter&) = delete; + HloProfilePrinter& operator=(const HloProfilePrinter&) = delete; + + // Converts the profile counter sequence `counters` to a human readable string + // representation. + string ToString(const int64* counters, double clock_rate_ghz) const; + + // Returns the size of the profile buffer expected by this printer. + int64 profile_counters_size() const { return profile_counters_size_; } + + ~HloProfilePrinter(); + + private: + // The `computation_infos_` field can be owning or non-owning -- this decision + // is manifested as the deleter_ function. + HloComputationInfo* computation_infos_ = nullptr; + int64 computation_infos_size_ = 0; + int64 profile_counters_size_ = 0; + std::function deleter_; +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index d7bdac9c86579f19afbba133772c2c50894853d1..553ec11f6f9a2997ab7113f9b8241e04c7fe20d5 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -30,11 +30,17 @@ namespace xla { class HloInstruction; -// A class for computing and representing reachability between HloInstructions. +// A class for representing reachability between HloInstructions. +// +// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix +// and it is up to the user of the class to set the adjacency matrix such that +// it represents reachability, i.e. such that it is transitive. That the graph +// be transitive is thus not an invariant of this class, but it is required for +// the name of the class and its methods to make sense. class HloReachabilityMap { public: - // Sets up an empty reachable matrix for the full set of instructions - // specified in 'instructions'. + // Sets up a graph with no edges and where the nodes correspond to the given + // instructions. explicit HloReachabilityMap(const std::list& instructions); // Set the reachability set of 'instruction' to the union of the reachability @@ -42,17 +48,33 @@ class HloReachabilityMap { // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from // itself. Returns whether the reachability set of 'instruction' changed. + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // vector in the internal graph of this HloReachabilityMap for the given + // instruction and does not transitively update any other part of the + // adjacency matrix. bool SetReachabilityToUnion( tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // matrix in the internal graph of this HloReachabilityMap to have an edge + // from a to b and does not transitively update any other part of the + // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); // Returns true if "b" is reachable from "a" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; private: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c96df50e79a3c6d4ca5f8e7e0abec33cdfca1c70..1747790e63c6af997eea096b68e5525fdd9d131a 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -62,16 +62,11 @@ bool IsRematerializable(const HloInstruction* instruction) { case HloOpcode::kConstant: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: - case HloOpcode::kOutfeed: - case HloOpcode::kInfeed: case HloOpcode::kParameter: - case HloOpcode::kRecv: - case HloOpcode::kSend: - case HloOpcode::kTrace: case HloOpcode::kWhile: return false; default: - return true; + return !instruction->HasSideEffect(); } } @@ -571,7 +566,9 @@ Status MemoryUsageTracker::BeginInstruction(Item* item) { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -608,8 +605,9 @@ Status MemoryUsageTracker::EndInstruction() { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); - + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -1026,7 +1024,9 @@ StatusOr HloRematerialization::RematerializeComputation( HloInstruction* best = best_item->instruction; VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " - << memory_tracker.MemoryReducedIfRematerialized(best_item) << ")"; + << HumanReadableNumBytes( + memory_tracker.MemoryReducedIfRematerialized(best_item)) + << ")"; changed = true; remat_count++; @@ -1106,8 +1106,8 @@ StatusOr HloRematerialization::RematerializeComputation( net_instructions_added++; } - VLOG(3) << "memory_usage after rematerialization = " - << memory_tracker.memory_usage(); + VLOG(1) << "memory_usage after rematerialization = " + << HumanReadableNumBytes(memory_tracker.memory_usage()); } const CallSite* callsite = call_graph_node.GetCallSite(instruction); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index d88aa4bb567c6c5f6eab54f12239bf7040339c39..c9b57166af438ef19ae4f079b8ecc8ddd5aede00 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -323,6 +323,76 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { EXPECT_EQ(inner_computation->instruction_count(), 8); } +TEST_F(HloRematerializationTest, RngNotRematerialized) { + // Test that a single rng is not rematerialized: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] rng = rng(param) + // F32[1024] tanh = tanh(rng) + // F32[1024] exp = exp(rng) + // F32[1024] add_0 = add(rng, tanh) // LIVE: add_0 + rng + + // // tanh + exp + // + // F32[1024] add_1 = add(rng, add(exp, add_0)) // LIVE: add_1 + add_0 + + // // rng + tanh + exp + // + // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + + // // rng + tanh + exp + auto module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto rng = builder.AddInstruction(HloInstruction::CreateRng( + vec1024_shape_, RandomDistribution::RNG_BERNOULLI, {param})); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng)); + auto add_0 = builder.AddInstruction( + HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh)); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, exp, add_0)))); + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, tanh, add_1)))); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto count_rngs = [](const HloComputation* computation) { + int64 rng_count = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kRng) { + ++rng_count; + } + } + return rng_count; + }; + // Before rematerialization there should be a single broadcast rng in + // the graph. + ASSERT_EQ(count_rngs(entry_computation), 1); + const int64 original_instruction_count = + entry_computation->instruction_count(); + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSERT_OK_AND_ASSIGN( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), + module.get(), &sequence)); + EXPECT_TRUE(changed); + // The rng should not have been rematerialized. + EXPECT_EQ(count_rngs(entry_computation), 1); + // There should have been rematerialization. + EXPECT_GT(entry_computation->instruction_count(), original_instruction_count); +} + TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Test that a single instruction is rematerialized several times. Module: // diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index c3f74e253f7a7882ec1c72e0ce634017dd2f0957..4a7caf3ebd81e4ca81400c67aa29a6a10bfe59d8 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/service/hlo_runner.h" @@ -19,8 +20,6 @@ limitations under the License. #include #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -30,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" @@ -40,11 +40,29 @@ namespace se = ::perftools::gputools; namespace xla { /*static*/ StatusOr> -HloRunner::ReadModuleFromHloProtoFile(const char* filename, +HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options) { + HloModuleConfig config; + config.set_debug_options(debug_options); + return tools::Parse(hlo_string, config); +} + +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, const DebugOptions& debug_options) { HloProto proto; - TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), - filename, &proto)); + + const Status s = + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &proto); + + if (!s.ok()) { + const Status s2 = + tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto); + if (!s2.ok()) { + return Status(s2.code(), s.error_message() + "\n" + s2.error_message()); + } + } + TF_ASSIGN_OR_RETURN( HloModuleConfig config, HloModule::CreateModuleConfigFromProto(proto.hlo_module())); @@ -54,6 +72,30 @@ HloRunner::ReadModuleFromHloProtoFile(const char* filename, return std::move(module); } +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, + const DebugOptions& debug_options) { + string hlo_string; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), + filename, &hlo_string)); + HloModuleConfig config; + config.set_debug_options(debug_options); + return tools::Parse(hlo_string, config); +} + +/*static*/ StatusOr> HloRunner::ReadModule( + const std::string& filename, const DebugOptions& debug_options) { + auto module = HloRunner::ReadModuleFromHloProtoFile(filename, debug_options); + if (module.ok()) { + return module; + } + const std::string e = module.status().error_message(); + module = HloRunner::ReadModuleFromHloTextDumpFile(filename, debug_options); + return module.ok() ? std::move(module) + : Status(module.status().code(), + e + "\n" + module.status().error_message()); +} + // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. struct HloRunner::EigenThreadPoolWrapper { @@ -80,11 +122,16 @@ HloRunner::~HloRunner() { StatusOr HloRunner::Execute( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments, - Shape* result_shape) { + Shape* result_shape, bool run_hlo_passes) { + if (run_hlo_passes) { + TF_ASSIGN_OR_RETURN( + module, backend().compiler()->RunHloPasses( + std::move(module), backend().default_stream_executor())); + } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend().compiler()->Compile(std::move(module), - backend().default_stream_executor())); + backend().compiler()->RunBackend(std::move(module), + backend().default_stream_executor())); se::Stream stream(backend().default_stream_executor()); stream.Init(); @@ -96,14 +143,13 @@ StatusOr HloRunner::Execute( run_options.set_intra_op_thread_pool( backend().eigen_intra_op_thread_pool_device()); - HloExecutionProfile hlo_execution_profile; ServiceExecutableRunOptions service_run_options( run_options, backend().StreamBorrower(), backend().inter_op_thread_pool()); TF_ASSIGN_OR_RETURN( se::DeviceMemoryBase result, executable->ExecuteOnStream(&service_run_options, arguments, - &hlo_execution_profile)); + /*hlo_execution_profile=*/nullptr)); TF_RET_CHECK(stream.BlockHostUntilDone()); allocations_.push_back(result); @@ -160,10 +206,12 @@ StatusOr> HloRunner::TransferFromDevice( StatusOr> HloRunner::ExecuteAndTransfer( std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes) { Shape result_shape; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase device_base, - Execute(std::move(module), arguments, &result_shape)); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase device_base, + Execute(std::move(module), arguments, &result_shape, run_hlo_passes)); return TransferFromDevice(result_shape, device_base); } diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index a4d7b653dbfbfdb169c07bca3e461147fd9d077a..a65c66fd4b6db858a532096a5ee466aa9bf0d844 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -35,7 +35,8 @@ namespace xla { // A base class for running an HloModule. This executes the given HloModule on a // certain backend directly without using the client interface. HloModule can be -// explicitly built, or loaded from a serialization file (e.g., hlo proto file). +// explicitly built, or loaded from a serialization file (e.g., hlo proto +// file), or parsed from a hlo textual IR string. class HloRunner { public: HloRunner(); @@ -44,25 +45,48 @@ class HloRunner { ~HloRunner(); - // Reads the binary proto file in xla.HloProto format, creates and returns the - // HloModule. + // Converts an HloModule from the given hlo textual IR string (in + // HloModule::ToString format). + static StatusOr> CreateModuleFromString( + const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options); + + // Reads the proto file in xla.HloProto format, creates and returns the + // HloModule. Will try to parse the filename as binary proto, then try as + // text proto if that fails. static StatusOr> ReadModuleFromHloProtoFile( - const char* filename, const DebugOptions& debug_options); + const std::string& filename, const DebugOptions& debug_options); + + // Reads the hlo text dump file in HloModule::ToString format, creates and + // returns the HloModule. + static StatusOr> ReadModuleFromHloTextDumpFile( + const std::string& filename, const DebugOptions& debug_options); + + // Tries to parse the filename specified first as binary proto format, then + // as a textual proto format, then textual IR, then gives up if both fail. + // ReadModuleFromHloProtoFile or ReadModuleFromHloTextDumpFile should be used + // explicitly when you know the format, this if you don't. + static StatusOr> ReadModule( + const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the // result as a Literal. The LiteralPtr type accepts Literal* or // std::unique_ptr. + // + // If run_hlo_passes is false, the module will be executed without Hlo + // optimization. template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals); + const tensorflow::gtl::ArraySlice literals, + bool run_hlo_passes = true); // Executes the given module and returns a global data handle. StatusOr Execute( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments, - Shape* result_shape); + Shape* result_shape, bool run_hlo_passes = true); // Transfers the given literal to the device and returns the data handle. StatusOr TransferToDevice( @@ -77,7 +101,8 @@ class HloRunner { StatusOr> ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice - arguments); + arguments, + bool run_hlo_passes = true); // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. @@ -99,14 +124,15 @@ class HloRunner { template StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals) { + const tensorflow::gtl::ArraySlice literals, + bool run_hlo_passes) { std::vector arguments; for (const auto& literal : literals) { TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, TransferToDevice(*literal)); arguments.push_back(argument); } - return ExecuteAndTransfer(std::move(module), arguments); + return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 8ccbcaeee4a9c9e94b344231953e20ac8f4b2053..0dc17392f1f520a415083c92b51db9d9abb321c0 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -31,6 +31,8 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +using ::tensorflow::strings::HumanReadableNumBytes; + namespace xla { StatusOr MinimumMemoryForSequence( @@ -375,6 +377,7 @@ StatusOr> CreateMemoryMinimizingSequence( // Note that this is just a heuristic. One obvious inaccuracy is that the // memory required for sub-computations might be different when considered // within the caller's context. But it's good enough for now. + VLOG(2) << "Computation: " << computation.name(); TF_ASSIGN_OR_RETURN( std::vector list_sequence, ListScheduler::Run(computation, points_to_analysis, size_function)); @@ -382,7 +385,7 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, @@ -391,13 +394,15 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); if (list_memory <= dfs_memory) { - VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Chose min-memory list sequence: " + << HumanReadableNumBytes(list_memory); return list_sequence; } else { - VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Chose min-memory dfs sequence: " + << HumanReadableNumBytes(dfs_memory); return dfs_sequence; } } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 0d019d22f5d4cd401c0fc5572f99636dec4f7383..447c2446668253c932b44b51b2db22bfd47f9957 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace xla { @@ -38,6 +39,15 @@ HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { } string HloSharding::ToString() const { + if (IsTuple()) { + std::vector parts; + parts.reserve(tuple_elements_.size()); + for (const HloSharding& element : tuple_elements_) { + parts.push_back(element.ToString()); + } + return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}"); + } + string result = StrCat("{", (replicated_ ? " replicated" : ""), (maximal_ ? " maximal" : "")); @@ -53,6 +63,11 @@ string HloSharding::ToString() const { } bool HloSharding::UsesDevice(int64 device) const { + if (IsTuple()) { + return std::any_of( + tuple_elements_.begin(), tuple_elements_.end(), + [&](const HloSharding& s) { return s.UsesDevice(device); }); + } const auto& devices = tile_assignment_; return replicated_ || std::find(devices.begin(), devices.end(), device) != devices.end(); @@ -61,6 +76,7 @@ bool HloSharding::UsesDevice(int64 device) const { std::vector HloSharding::TileIndexForDevice(int64 device) const { CHECK(!ShapeUtil::IsTuple(tile_shape_)); CHECK(!maximal_); + CHECK(!IsTuple()); std::vector ret_index; tile_assignment_.Each([&](tensorflow::gtl::ArraySlice index, int64 d) { if (d == device) { @@ -74,6 +90,7 @@ std::vector HloSharding::TileIndexForDevice(int64 device) const { int64 HloSharding::DeviceForTileIndex( tensorflow::gtl::ArraySlice index) const { CHECK(!replicated_); + CHECK(!IsTuple()); if (maximal_) { return *tile_assignment_.begin(); } @@ -82,7 +99,7 @@ int64 HloSharding::DeviceForTileIndex( } std::vector HloSharding::TileOffsetForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); + CHECK(!IsTuple()); std::vector index = TileIndexForDevice(device); if (maximal_) { @@ -97,7 +114,7 @@ std::vector HloSharding::TileOffsetForDevice(int64 device) const { } std::vector HloSharding::TileLimitForDevice(int64 device) const { - CHECK(!ShapeUtil::IsTuple(tile_shape_)); + CHECK(!IsTuple()); CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. std::vector index = TileIndexForDevice(device); @@ -108,14 +125,94 @@ std::vector HloSharding::TileLimitForDevice(int64 device) const { } StatusOr HloSharding::UniqueDevice() const { - if (!replicated_ && maximal_) { + if (IsTuple()) { + if (tuple_elements_.empty()) { + return tensorflow::errors::InvalidArgument( + "UniqueDevice() called on empty tuple"); + } + std::vector> results; + std::transform(tuple_elements_.begin(), tuple_elements_.end(), + std::back_inserter(results), + [](const HloSharding& s) { return s.UniqueDevice(); }); + if (std::all_of(results.begin(), results.end(), + [&](const StatusOr& s) { + return s.ok() && results[0].ok() && + s.ValueOrDie() == results[0].ValueOrDie(); + })) { + return results[0]; + } else { + return tensorflow::errors::InvalidArgument( + "Tuple did not contain a unique device"); + } + } + if (!replicated_ && maximal_ && !IsTuple()) { return static_cast(*tile_assignment_.begin()); } return tensorflow::errors::InvalidArgument( "UniqueDevice() called on sharding that executes on multiple devices"); } +bool HloSharding::HasUniqueDevice() const { + if (IsTuple()) { + return UniqueDevice().status().ok(); + } else { + return !IsReplicated() && IsTileMaximal(); + } +} + +Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { + if (!ShapeUtil::IsTuple(shape)) { + return tensorflow::errors::InvalidArgument( + StrCat("Sharding is tuple-shaped but validation shape is not.")); + } + // The easiest way to get the number of elements in a nested tuple is just to + // create a shape tree. We could call GetAsShapeTree, but that will try and + // apply our tuple_shardings_ to the shape tree, and that might cause a crash + // at this point as we haven't validated them. + ShapeTree bool_shape_tree(shape, false); + int64 num_leaves = + std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end()); + if (num_leaves != tuple_elements_.size()) { + return tensorflow::errors::InvalidArgument( + StrCat("Validation tuple shape has ", num_leaves, + " leaf elements, but this sharding contains ", + tuple_elements_.size(), " elements.")); + } + + // Now we've validated the number of tuple elements, it's safe to request a + // shape tree. + ShapeTree shape_tree = GetAsShapeTree(shape); + for (const auto& index_to_sharding : shape_tree.leaves()) { + Status status = index_to_sharding.second.ValidateNonTuple( + ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); + if (!status.ok()) { + tensorflow::errors::AppendToMessage( + &status, StrCat("Note: While validating sharding tuple element ", + index_to_sharding.first.ToString(), " which is ", + index_to_sharding.second.ToString())); + return status; + } + } + return Status::OK(); +} + Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { + Status status = IsTuple() ? ValidateTuple(shape, num_devices) + : ValidateNonTuple(shape, num_devices); + if (!status.ok()) { + tensorflow::errors::AppendToMessage( + &status, StrCat("Note: While validating sharding ", ToString(), + " against shape ", ShapeUtil::HumanString(shape))); + } + return status; +} + +Status HloSharding::ValidateNonTuple(const Shape& shape, + int64 num_devices) const { + if (ShapeUtil::IsTuple(shape)) { + return tensorflow::errors::InvalidArgument( + StrCat("Validation shape is a tuple but sharding is not.")); + } if (replicated_) { return Status::OK(); } @@ -129,13 +226,11 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { - status = - tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "core ", core, " > ", num_devices, " in tile assignment")); + status = tensorflow::errors::InvalidArgument(StrCat( + "core ", core, " > ", num_devices, " in tile assignment")); } else if (seen_cores.count(core) != 0) { - status = - tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "core ", core, " is not unique in tile assignment")); + status = tensorflow::errors::InvalidArgument( + StrCat("core ", core, " is not unique in tile assignment")); } } seen_cores.insert(core); @@ -151,7 +246,8 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { // The tile rank must be the same as the input rank. if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank"); + "Tile rank is different to the input rank. sharding=", ToString(), + ", input_shape=", ShapeUtil::HumanString(shape)); } // The tile shape must not be the same as the input shape without maximal_ @@ -169,9 +265,9 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { auto tile_dim = tile_shape_.dimensions(i); auto shape_dim = shape.dimensions(i); if (tile_dim > shape_dim) { - return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "Tile is larger than input shape (dimension ", i, ", ", tile_dim, - " > ", shape_dim)); + return tensorflow::errors::InvalidArgument( + StrCat("Tile is larger than input shape (dimension ", i, ", ", + tile_dim, " > ", shape_dim)); } } @@ -181,10 +277,10 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { int64 expected_dim = CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i)); if (tile_assignment_.dimensions()[i] != expected_dim) { - return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( - "Tile assignment tensor has incorrect shape. Dimension ", i, - " expected ", expected_dim, " but got ", - tile_assignment_.dimensions()[i])); + return tensorflow::errors::InvalidArgument( + StrCat("Tile assignment tensor has incorrect shape. Dimension ", i, + " expected ", expected_dim, " but got ", + tile_assignment_.dimensions()[i])); } } @@ -193,9 +289,19 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { /*static*/ StatusOr HloSharding::FromProto( const OpSharding& proto) { - if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) { + std::vector tuple_shardings; + tuple_shardings.reserve(proto.tuple_shardings().size()); + for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) { + TF_ASSIGN_OR_RETURN(HloSharding sharding, + HloSharding::FromProto(tuple_sharding_proto)); + tuple_shardings.push_back(sharding); + } + return HloSharding(tuple_shardings); + } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { return Replicate(); - } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) { + } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL || + proto.tile_assignment_devices().size() == 1) { return HloSharding(proto.tile_assignment_devices(0)); } // Some versions of gcc cannot infer the TileAssignment constructor from a @@ -212,6 +318,15 @@ Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { OpSharding HloSharding::ToProto() const { OpSharding result; + + if (IsTuple()) { + for (const HloSharding& element : tuple_elements_) { + *result.add_tuple_shardings() = element.ToProto(); + } + result.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + return result; + } + *result.mutable_tile_shape() = tile_shape_; for (int64 dim : tile_assignment_.dimensions()) { result.add_tile_assignment_dimensions(dim); diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index d7ada30c70bc3b41b3117375380eac2e883d9a9d..7263198385cf0c84b1dac1e15177dcac99adaafb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/hash/hash.h" @@ -67,6 +68,29 @@ class HloSharding { // `num_tiles` tiles. static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); + // Creates a new sharding for a tuple type. The given ShapeTree must have + // elements for every leaf shape contained in the tuple. + static HloSharding Tuple(const ShapeTree& sub_shardings) { + std::vector flattened_list; + flattened_list.reserve( + std::distance(sub_shardings.leaf_begin(), sub_shardings.leaf_end())); + for (const auto& index_to_sharding : sub_shardings.leaves()) { + flattened_list.push_back(index_to_sharding.second); + } + return HloSharding(flattened_list); + } + + // Creates a new sharding for a tuple type. The requested tuple shape must not + // be nested. For nested tuples, use the ShapeTree overload. + static HloSharding Tuple(const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)); + CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); + return HloSharding(flattened_list); + } + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); @@ -76,47 +100,93 @@ class HloSharding { // Validate that this sharding can be applied to a tensor with shape `shape`. Status Validate(const Shape& shape, int64 num_devices) const; + // Returns true if the sharding has tuple type. + bool IsTuple() const { return tuple_; } + // Returns true if the sharding is trivial: replicate on all devices. - bool IsReplicated() const { return replicated_; } + bool IsReplicated() const { + if (!IsTuple()) { + return replicated_; + } + return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), + [](const HloSharding& s) { return s.IsReplicated(); }); + } // Returns true if the tile size is the same as the input size. - bool IsTileMaximal() const { return maximal_; } + bool IsTileMaximal() const { + if (!IsTuple()) { + return maximal_; + } + return std::all_of(tuple_elements_.begin(), tuple_elements_.end(), + [](const HloSharding& s) { return s.IsTileMaximal(); }); + } // Returns true if the sharding defines an operation on the given device. bool UsesDevice(int64 device) const; // Returns the tile that should be executed on the given device. + // REQUIRES: !IsTuple() std::vector TileIndexForDevice(int64 device) const; // Returns the device that should execute the given tile. // It is an error to call this if is_replicated() is true. + // REQUIRES: !IsTuple() int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; // Given a device ID, returns the offset within the input space of the // tile that should be executed on the given core. This returns the lower // extent of the tile in the input space. + // REQUIRES: !IsTuple() std::vector TileOffsetForDevice(int64 device) const; // Given a device ID, returns the limit within the input space of the // tile that should be executed on the given core. This returns the upper // extent of the tile in the input space. + // REQUIRES: !IsTuple() std::vector TileLimitForDevice(int64 device) const; // Returns the single device this op operates on. - // Requires !Replicated() && IsTileMaximal(). + // REQUIRES: !IsTuple&& !Replicated() && IsTileMaximal() StatusOr UniqueDevice() const; // Returns true if this op only uses a single device. - bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); } + bool HasUniqueDevice() const; + + // Returns the ShapeTree containing the shardings for each element of this + // tuple, if IsTuple, or a ShapeTree with a single element containing this + // sharding. Only the leaf elements are populated. This creates a new + // ShapeTree object so is not cheap. + ShapeTree GetAsShapeTree(const Shape& shape) const { + if (IsTuple()) { + ShapeTree result(shape, HloSharding::Replicate()); + CHECK_EQ(std::distance(result.leaf_begin(), result.leaf_end()), + tuple_elements_.size()); + auto it = tuple_elements_.begin(); + for (auto& index_to_sharding : result.leaves()) { + index_to_sharding.second = *it++; + } + return result; + } else { + return ShapeTree(shape, *this); + } + } bool operator==(const HloSharding& other) const { return replicated_ == other.replicated_ && maximal_ == other.maximal_ && protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && - tile_assignment_ == other.tile_assignment_; + tile_assignment_ == other.tile_assignment_ && + tuple_elements_ == other.tuple_elements_; } bool operator!=(const HloSharding& other) const { return !(*this == other); } size_t Hash() const { + if (!tuple_) { + size_t h = 0; + for (const auto& element : tuple_elements_) { + h = tensorflow::Hash64Combine(h, element.Hash()); + } + return h; + } if (replicated_) { return 0; } @@ -131,33 +201,52 @@ class HloSharding { } // Gets the tile shape. - // It is an error to call this if IsTileMaximal() is true. + // REQUIRES: !IsTileMaximal() && !IsTuple() const Shape& tile_shape() const { return tile_shape_; } // Gets the tile assignment tensor. - // It is an error to call this if IsReplicated() is true. + // REQUIRES: !IsReplicated() && !IsTuple() const Array& tile_assignment() const { return tile_assignment_; } private: HloSharding() : replicated_(true), maximal_(true), + tuple_(false), tile_shape_(), tile_assignment_({0}) {} explicit HloSharding(int64 device_id) : replicated_(false), maximal_(true), + tuple_(false), tile_shape_(), tile_assignment_({1}, device_id) {} HloSharding(const Shape& tile_shape, const Array& tile_assignment) : replicated_(false), maximal_(false), + tuple_(false), tile_shape_(tile_shape), tile_assignment_(tile_assignment) {} + HloSharding(const std::vector& tuple_shardings) + : replicated_(false), + maximal_(false), + tuple_(true), + tile_assignment_({0}), + tuple_elements_(tuple_shardings) {} + + // Internal helper to validate a tuple sharding. + Status ValidateTuple(const Shape& shape, int64 num_devices) const; + // Internal helper to validate a non-tuple (leaf) sharding. + Status ValidateNonTuple(const Shape& shape, int64 num_devices) const; bool replicated_; bool maximal_; + bool tuple_; Shape tile_shape_; Array tile_assignment_; + // Only non-empty when tuple_ is true, but because empty tuples are allowed + // may also be empty even then. This is a flattened list of all the leaf + // shardings in a tuple shape, by pre-order walk (ShapeTree iterator order). + std::vector tuple_elements_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc index d0a20471a0f22a5fa414b71bb5160eed7cdc431b..0c7487b3ac77ff181d44dd55ebcf2608feaf02ea 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -70,6 +70,11 @@ TEST_F(HloShardingTest, DevicePlacement) { /*num_devices=*/6)); EXPECT_IS_NOT_OK( sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/5)); + + ShapeTree shape_tree = + sharding.GetAsShapeTree(ShapeUtil::MakeShape(U32, {4})); + EXPECT_EQ(shape_tree.element({}), sharding); + EXPECT_TRUE(shape_tree.IsLeaf({})); } TEST_F(HloShardingTest, Tile) { @@ -132,6 +137,39 @@ TEST_F(HloShardingTest, Tile) { } } +TEST_F(HloShardingTest, NestedTuple) { + // nested_tuple_shape = (f32[], (f32[3]), f32[4, 6]) + Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {3})}), + ShapeUtil::MakeShape(F32, {4, 6}), + }); + + HloSharding tiled_sharding = HloSharding::Tile( + ShapeUtil::MakeShape(F32, {4, 3}), Array({{0, 1}})); + OpSharding proto; + proto.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + *proto.add_tuple_shardings() = HloSharding::Replicate().ToProto(); + *proto.add_tuple_shardings() = HloSharding::AssignDevice(0).ToProto(); + *proto.add_tuple_shardings() = tiled_sharding.ToProto(); + HloSharding tuple_sharding = + HloSharding::FromProto(proto).ConsumeValueOrDie(); + + ShapeTree shape_tree = + tuple_sharding.GetAsShapeTree(nested_tuple_shape); + EXPECT_EQ(shape_tree.element({0}), HloSharding::Replicate()); + EXPECT_EQ(shape_tree.element({1, 0}), HloSharding::AssignDevice(0)); + EXPECT_EQ(shape_tree.element({2}), tiled_sharding); + + EXPECT_IS_OK(tuple_sharding.Validate(nested_tuple_shape, /*num_devices=*/5)); + // Test should fail because tuple element count does not match. + EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeTupleShape({}), + /*num_devices=*/5)); + // Test should fail because the input type is not a tuple. + EXPECT_IS_NOT_OK(tuple_sharding.Validate(ShapeUtil::MakeShape(F32, {}), + /*num_devices=*/5)); +} + TEST_F(HloShardingTest, Hash) { auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) { if (a.Hash() != b.Hash()) { @@ -184,6 +222,51 @@ TEST_F(HloShardingTest, Hash) { MakeArray({2, 2}, {0, 3, 1, 2})); EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); } + + HloSharding default_sharding = HloSharding::Replicate(); + { + ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Replicate(); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree(ShapeUtil::MakeTupleShape({}), + default_sharding); + HloSharding sharding1 = HloSharding::Tuple(shape_tree); + HloSharding sharding2 = HloSharding::Tuple(shape_tree); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::Replicate(); + ShapeTree shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + ShapeTree shape_tree1( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree1.mutable_element({0}) = HloSharding::AssignDevice(0); + ShapeTree shape_tree2( + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {4})}), + default_sharding); + *shape_tree2.mutable_element({0}) = HloSharding::AssignDevice(0); + HloSharding sharding1 = HloSharding::Tuple(shape_tree1); + HloSharding sharding2 = HloSharding::Tuple(shape_tree2); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 06abe007477dbcd00bcdc7f2656c4dece6d1cf74..101a710d1cad9401134fdfe1d0ec9df241bc01e1 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -58,8 +58,6 @@ TensorShapeProto GetTensorShape(const HloInstruction* instruction) { string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } -} // namespace - void CleanNodeName(string* name) { name->erase(std::remove(name->begin(), name->end(), '%'), name->end()); const string chars_to_replace = "<>[]"; @@ -70,6 +68,11 @@ void CleanNodeName(string* name) { std::replace_if(name->begin(), name->end(), pred, '_'); } +} // namespace + +HloTfGraphBuilder::HloTfGraphBuilder(const DebugOptions& debug_options) + : debug_options_(debug_options) {} + Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { VLOG(2) << "Adding computation " << computation.name(); for (auto embedded : computation.MakeEmbeddedComputationsList()) { @@ -90,24 +93,38 @@ const string& HloTfGraphBuilder::GetNodeNameForInstruction( if (ContainsKey(instruction_to_node_name_, instruction)) { return instruction_to_node_name_[instruction]; } + auto append = [](string* str, const string& other) { + if (str->empty()) { + *str = other; + } else if (!other.empty()) { + StrAppend(str, "/", other); + } + }; string node_name; + if (debug_options_.xla_hlo_tfgraph_device_scopes() && + instruction->has_sharding() && + instruction->sharding().HasUniqueDevice()) { + node_name = StrCat( + "dev", instruction->sharding().UniqueDevice().ConsumeValueOrDie()); + } // If an instruction is fused, put it in the subgraph of the fusion; // otherwise, put it in the computation subgraph. const HloComputation* computation = instruction->parent(); if (computation->IsFusionComputation()) { - node_name = GetNodeNameForInstruction(computation->FusionInstruction()); + append(&node_name, + GetNodeNameForInstruction(computation->FusionInstruction())); } else { - node_name = computation->name(); + append(&node_name, computation->name()); if (!instruction->metadata().op_name().empty()) { // Always make computations contain TF ops but not the other way around. - StrAppend(&node_name, "/", instruction->metadata().op_name()); + append(&node_name, instruction->metadata().op_name()); } } string instruction_name = instruction->name(); if (instruction->opcode() == HloOpcode::kParameter) { StrAppend(&instruction_name, ".", instruction->parameter_number()); } - StrAppend(&node_name, "/", instruction_name); + append(&node_name, instruction_name); CleanNodeName(&node_name); auto ret = instruction_to_node_name_.insert(std::make_pair(instruction, node_name)); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h index b2c578af912ac0b777d1bc72a198504735a6b845..9aa3e501d5f85e3b61b20555e3d13c5687f33f2f 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h @@ -17,6 +17,7 @@ limitations under the License. #define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -26,6 +27,8 @@ namespace hlo_graph_dumper { // This constructs a tensorflow graph for HLO computations. class HloTfGraphBuilder { public: + HloTfGraphBuilder(const DebugOptions& debug_options = DebugOptions()); + // Adds a computation to the graph. Status AddComputation(const HloComputation& computation); @@ -42,6 +45,7 @@ class HloTfGraphBuilder { Status AddInstruction(const HloInstruction* instruction); + DebugOptions debug_options_; tensorflow::GraphDef graph_def_; // This records instructions that have been visited. std::unordered_set visited_instructions_; @@ -49,9 +53,6 @@ class HloTfGraphBuilder { std::unordered_map instruction_to_node_name_; }; -// Cleans the node name to make it a valid name in a tensorflow graph. -void CleanNodeName(string* name); - } // namespace hlo_graph_dumper } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc index e6cf0d37b8a0f42dc04cfaad067a4741bc803705..05b7dce3d1ecf935b80ba1cb46ef089b7b3b6f33 100644 --- a/tensorflow/compiler/xla/service/hlo_value.cc +++ b/tensorflow/compiler/xla/service/hlo_value.cc @@ -71,7 +71,7 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi) : id_(id), is_phi_(is_phi) { // The defining position is always the first element in the positions_ vector. - AddPosition(instruction, index); + positions_.push_back(HloPosition{instruction, index}); } bool HloValue::operator==(const HloValue& other) const { @@ -130,18 +130,14 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, CHECK_LE(operand_number, 2); return operand_number == 0 || index.empty(); - case HloOpcode::kCall: case HloOpcode::kTuple: // These instructions always pass through their operands transparently. return false; + case HloOpcode::kCall: case HloOpcode::kWhile: - // Though the while instructions passes through its operands, we return - // true because in SSA form there may be a Phi at the parameter of the - // while which is considered a use of its incoming value because the Phi - // input values are not passed through into the body computation. Because - // this function is used in both SSA and non-SSA forms of the analysis - // conservatively return true. + // Although call and while instructions pass through their operands, they + // are considered uses. return true; default: @@ -151,103 +147,58 @@ bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, } // namespace -void HloValue::AddPosition(HloInstruction* instruction, - const ShapeIndex& index) { - HloPosition new_position{instruction, index}; - - // The new position must not already exist in positions_. - for (const HloPosition& position : positions_) { - DCHECK_NE(position, new_position); - } - - positions_.push_back(std::move(new_position)); - - // Update uses. - for (HloInstruction* user : instruction->users()) { - for (int64 operand_number : user->OperandIndices(instruction)) { - if (MayUseOperandValue(operand_number, index, user)) { - HloUse new_use{user, operand_number, index}; - - // The new use must not already exist in uses_. - for (const HloUse& use : uses_) { - DCHECK_NE(use, new_use); - } - - uses_.push_back(std::move(new_use)); +void HloValue::SetPositionsAndComputeUses( + tensorflow::gtl::ArraySlice positions) { + CHECK_EQ(positions_.size(), 1) << "SetPositions should only be called once."; + + // The positions must be unique and should not contain the defining position + // as this is added at construction time. + for (const HloPosition& position_a : positions) { + DCHECK_NE(position_a, defining_position()); + for (const HloPosition& position_b : positions) { + if (&position_a != &position_b) { + DCHECK_NE(position_a, position_b); } } } - // Update liveout status of this HloValue. - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - } - - if (instruction == instruction->parent()->root_instruction()) { - live_out_of_computation_ = true; - } -} + positions_.insert(positions_.end(), positions.begin(), positions.end()); -void HloValue::RemovePosition(HloInstruction* instruction, - const ShapeIndex& index) { - // The defining position cannot be removed. - CHECK(!(instruction == defining_instruction() && index == defining_index())); - - int64 size_before = positions_.size(); - positions_.erase( - std::remove_if(positions_.begin(), positions_.end(), - [instruction, &index](const HloPosition& position) { - return position.instruction == instruction && - position.index == index; - }), - positions_.end()); - // Only a single position should have been removed. - CHECK_EQ(positions_.size(), size_before - 1); - - // Update uses which referred to this position. - uses_.erase(std::remove_if(uses_.begin(), uses_.end(), - [instruction, &index](const HloUse& use) { - return use.instruction->operand( - use.operand_number) == instruction && - use.operand_index == index; - }), - uses_.end()); - - // Returns whether this value is contained in the given instruction's output. - auto is_contained_in = [this](const HloInstruction* instruction) { - for (const HloPosition& position : positions()) { - if (position.instruction == instruction) { - return true; - } + // Gather the computation roots at which this value appears. + tensorflow::gtl::FlatSet root_positions; + for (const HloPosition& position : positions_) { + if (position.instruction == + position.instruction->parent()->root_instruction()) { + root_positions.insert(position.instruction); } - return false; - }; - - const HloModule& module = *instruction->parent()->parent(); - if (instruction == module.entry_computation()->root_instruction()) { - // Value has been removed from a position in the entry root instruction. - live_out_of_module_ = - is_contained_in(module.entry_computation()->root_instruction()); - } - if (instruction == defining_instruction()->parent()->root_instruction()) { - // Value has been removed from the root of the computation the value has - // been defined in. - live_out_of_computation_ = - is_contained_in(defining_instruction()->parent()->root_instruction()); } -} -void HloValue::RecomputeUses() { - uses_.clear(); - for (const HloPosition& position : positions()) { + // Build vector of HloUses for the value. + for (const HloPosition& position : positions_) { for (HloInstruction* user : position.instruction->users()) { for (int64 operand_number : user->OperandIndices(position.instruction)) { - if (MayUseOperandValue(operand_number, position.index, user)) { - uses_.push_back(HloUse{user, operand_number, position.index}); + // Root instructions of computations are considered to be uses whether + // or not the root instruction itself actually uses the value. + if (MayUseOperandValue(operand_number, position.index, user) || + ContainsKey(root_positions, user)) { + HloUse new_use{user, operand_number, position.index}; + + // The new use must not already exist in uses_. + for (const HloUse& use : uses_) { + DCHECK_NE(use, new_use); + } + + uses_.push_back(std::move(new_use)); } } } + + // Update liveout status of this HloValue. + const HloModule& module = *position.instruction->parent()->parent(); + if (position.instruction == + module.entry_computation()->root_instruction()) { + live_out_of_module_ = true; + } } } diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h index 6872bc76a82253b916e826aa1afabc3d309c1d12..2a711e8b42590c29d0aaab95dcf110063ada3182 100644 --- a/tensorflow/compiler/xla/service/hlo_value.h +++ b/tensorflow/compiler/xla/service/hlo_value.h @@ -121,6 +121,12 @@ class HloValue { HloValue(Id id, HloInstruction* instruction, const ShapeIndex& index, bool is_phi = false); + // Sets the positions in the module at which the HloValue appears. Updates + // uses. Should be called once and only once. The defining position should not + // be included in 'positions' as this is set at construction time. + void SetPositionsAndComputeUses( + tensorflow::gtl::ArraySlice positions); + // Return a unique identifier for this HloValue. This value is used for stable // sorting and iteration Id id() const { return id_; } @@ -143,28 +149,15 @@ class HloValue { // Return the shape of this HloValue. const Shape& shape() const { return defining_position().shape(); } - // Add or remove a position at which the HloValue appears. The definition - // position can not be removed. The uses of the HloValue are updated. - void AddPosition(HloInstruction* instruction, const ShapeIndex& index); - void RemovePosition(HloInstruction* instruction, const ShapeIndex& index); - - // Remove all positions except the defining position. Updates uses. - void ClearPositions(); - // Return all positions of the HloValue in the module. const std::vector& positions() const { return positions_; } // Return all uses of the HloValue. const std::vector& uses() const { return uses_; } - void RecomputeUses(); - // Get whether this HloValue is live out of the module. bool live_out_of_module() const { return live_out_of_module_; } - // Get whether this HloValue is live out of the computation it is defined in. - bool live_out_of_computation() const { return live_out_of_computation_; } - bool operator==(const HloValue& other) const; bool operator!=(const HloValue& other) const; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index c1aa655401a2be68af943e2ed29c4ab99d341383..b8fd7a89efd4d86630eed1f29db5b7b1b7876d23 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -59,21 +59,27 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleConvert(HloInstruction* convert) override { - if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) { - TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape())) - << "Unsupported complex->real kConvert"; - } return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } + Status HandleBitcastConvert(HloInstruction* convert) override { + return CheckShape(convert, ShapeInference::InferBitcastConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); + } + Status HandleCopy(HloInstruction* copy) override { return CheckUnaryShape(copy); } Status HandleDot(HloInstruction* dot) override { - return CheckBinaryShape(dot); + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers())); + return CheckShape(dot, expected); } Status HandleConvolution(HloInstruction* convolution) override { @@ -87,8 +93,12 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleCrossReplicaSum(HloInstruction* crs) override { - return CheckShape(crs, ShapeInference::InferCrossReplicaSumShape( - crs->operand(0)->shape())); + std::vector operand_shapes; + for (const HloInstruction* operand : crs->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape( + crs, ShapeInference::InferCrossReplicaSumShape(operand_shapes)); } Status HandleReducePrecision(HloInstruction* reduce_precision) override { @@ -141,9 +151,6 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleBitcast(HloInstruction* bitcast) override { - // Bitcasts can be any shape, as long as the size matches the operand size. - TF_RET_CHECK(shape_size_fn_(bitcast->shape()) == - shape_size_fn_(bitcast->operand(0)->shape())); return tensorflow::Status::OK(); } @@ -263,6 +270,15 @@ class ShapeVerifier : public DfsHloVisitor { xla_while->while_body()->ComputeProgramShape().result()); } + Status HandleConditional(HloInstruction* conditional) override { + TF_RETURN_IF_ERROR(CheckShape( + conditional, + conditional->true_computation()->ComputeProgramShape().result())); + return CheckShape( + conditional, + conditional->false_computation()->ComputeProgramShape().result()); + } + Status HandlePad(HloInstruction* pad) override { return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), @@ -270,12 +286,40 @@ class ShapeVerifier : public DfsHloVisitor { pad->padding_config())); } - Status HandleSend(HloInstruction*) override { - return tensorflow::Status::OK(); + Status HandleSend(HloInstruction* send) override { + TF_RET_CHECK(send->users().size() == 1); + const HloInstruction* send_done = send->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape( + send, ShapeUtil::MakeTupleShape( + {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); } - Status HandleRecv(HloInstruction*) override { - return tensorflow::Status::OK(); + Status HandleSendDone(HloInstruction* send_done) override { + TF_RET_CHECK(send_done->operands().size() == 1); + const HloInstruction* send = send_done->operand(0); + TF_RET_CHECK(send->opcode() == HloOpcode::kSend); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape(send_done, ShapeUtil::MakeNil()); + } + + Status HandleRecv(HloInstruction* recv) override { + TF_RET_CHECK(recv->users().size() == 1); + const HloInstruction* recv_done = recv->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv, + ShapeUtil::MakeTupleShape( + {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); + } + + Status HandleRecvDone(HloInstruction* recv_done) override { + TF_RET_CHECK(recv_done->operands().size() == 1); + const HloInstruction* recv = recv_done->operand(0); + TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv_done, recv->shape().tuple_shapes(0)); } Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { @@ -365,6 +409,19 @@ class ShapeVerifier : public DfsHloVisitor { instruction->opcode(), instruction->operands())); } + // Checks if the given two instructions shares the same channel id. + Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return FailedPrecondition( + "Expected to have the same channel id, actual channel ids are: %s " + "(%lld), %s (%lld)", + instr1->ToString().c_str(), instr1->channel_id(), + instr2->ToString().c_str(), instr2->channel_id()); + } + return tensorflow::Status::OK(); + } + // Returns the size of a Shape in bytes. const std::function shape_size_fn_; }; @@ -530,7 +587,7 @@ StatusOr HloVerifier::Run(HloModule* module) { // or ComputationLowerer::Visit() TF_RET_CHECK(instruction->dimensions().size() == ShapeUtil::Rank(instruction->operand(0)->shape())) - << "Broadcast HLO has invalid number of dimensions."; + << "Broadcast HLO has invalid number of dimensions."; } else if (instruction->opcode() == HloOpcode::kWhile) { auto* while_cond = instruction->while_condition(); auto* while_body = instruction->while_body(); diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index d620f45d27eba706fbd7fc30d3b27b0d963475d4..b7c40fdeeb157fc74900bd9cf9d68a06a2cb1d56 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -68,12 +68,20 @@ string HumanReadableProfileBuilder::ToString() const { }; float optimal_seconds_sum = 0.0; + int64 total_flops = 0.; + int64 total_transcendentals = 0.; + int64 total_bytes = 0; for (const auto& op : op_infos_) { optimal_seconds_sum += op.optimal_seconds; + total_flops += op.flop_count; + total_transcendentals += op.transcendental_count; + total_bytes += op.bytes_accessed; } - append_op({"[total]", "[total]", /*category=*/"", total_cycles_, -1, -1, -1, - optimal_seconds_sum}); + VLOG(1) << "Total floating point ops: " << total_flops; + + append_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, + total_transcendentals, total_bytes, optimal_seconds_sum}); // Sort ops in decreasing order of cycles. std::vector sorted_ops(op_infos_); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 0d1b7bc109c56bc4290ede09284c6d20142bdb08..ba901b99e4f3c72c84c1ecdf4e19e58ad9ab6506 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -33,7 +33,9 @@ namespace xla { switch (instruction.opcode()) { // Cheap instructions. case HloOpcode::kAdd: + case HloOpcode::kAnd: case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: case HloOpcode::kBroadcast: case HloOpcode::kCeil: case HloOpcode::kClamp: @@ -53,15 +55,14 @@ namespace xla { case HloOpcode::kInfeed: case HloOpcode::kIsFinite: case HloOpcode::kLe: - case HloOpcode::kAnd: - case HloOpcode::kNot: - case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: case HloOpcode::kNe: case HloOpcode::kNegate: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kOutfeed: case HloOpcode::kPad: case HloOpcode::kReal: @@ -88,10 +89,11 @@ namespace xla { // Expensive instructions. case HloOpcode::kAtan2: - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: + case HloOpcode::kBatchNormInference: + case HloOpcode::kBatchNormTraining: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: @@ -103,17 +105,19 @@ namespace xla { case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: case HloOpcode::kReduce: case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: case HloOpcode::kSelectAndScatter: + case HloOpcode::kSend: + case HloOpcode::kSendDone: case HloOpcode::kSort: case HloOpcode::kTanh: case HloOpcode::kTrace: case HloOpcode::kWhile: - case HloOpcode::kSend: - case HloOpcode::kRecv: return true; } diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index b273f091f148ad2155067782a51adb41ae557797..2704a805a91b93c69b751cdb61305ea7780f0ef2 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -52,8 +52,8 @@ cc_library( "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", ], alwayslink = True, # Contains compiler registration diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 93ea2f736742eab06ee0d7e881ee7c51daee9878..dc63a2224d659fa427d4d1a30c5dc0f94d643b36 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -56,6 +57,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { pipeline.AddPass>( false, [](const Shape&, const Shape&) { return false; }); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(true); @@ -67,13 +69,19 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { return pipeline.Run(hlo_module).status(); } -StatusOr> InterpreterCompiler::Compile( +StatusOr> InterpreterCompiler::RunHloPasses( + std::unique_ptr hlo_module, + se::StreamExecutor* /*stream_exec*/) { + VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); + TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); + return std::move(hlo_module); +} + +StatusOr> InterpreterCompiler::RunBackend( std::unique_ptr hlo_module, se::StreamExecutor* stream_exec) { TF_RET_CHECK(stream_exec != nullptr); - VLOG(1) << "Generate graph " << hlo_module->name(); - - TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); + VLOG(1) << "Run backend " << hlo_module->name(); // Typically you would visit the HLO graph, building up a compiled equivalent // In this case we are using an HloEvaluator at execution time, so we don't diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index cfdc9b6256569b0137784b0d1db846a5f2339a5d..278cf5184227ae25518b1d46c0e16e4cce7bd1a8 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -43,8 +43,12 @@ class InterpreterCompiler : public Compiler { InterpreterCompiler() {} ~InterpreterCompiler() override {} - StatusOr> Compile( - std::unique_ptr hlo_modules, + StatusOr> RunHloPasses( + std::unique_ptr hlo_module, + perftools::gputools::StreamExecutor* stream_exec) override; + + StatusOr> RunBackend( + std::unique_ptr hlo_module, perftools::gputools::StreamExecutor* stream_exec) override; StatusOr>> Compile( diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 86dee8462fd4fdda580ada892e244f19177fb3e5..9183a1d1bfb8c2f6e1933c004f9c9f5f9ad8eced 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -42,7 +42,8 @@ namespace sep = ::perftools::gputools::interpreter; InterpreterExecutable::InterpreterExecutable( std::unique_ptr hlo_module) - : Executable(std::move(hlo_module)) {} + : Executable(std::move(hlo_module), /*hlo_profile_printer=*/nullptr, + /*hlo_profile_index_map=*/nullptr) {} InterpreterExecutable::~InterpreterExecutable() {} @@ -89,7 +90,7 @@ StatusOr InterpreterExecutable::ExecuteOnStream( uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - HloComputation* computation = module().entry_computation(); + const HloComputation* computation = module().entry_computation(); if (computation->num_parameters() != arguments.size()) { return tensorflow::errors::Internal( "Mismatch between argument count and graph parameter count."); @@ -156,10 +157,5 @@ StatusOr InterpreterExecutable::ExecuteAsyncOnStream( return ShapeUtil::ByteSizeOf(shape, sizeof(void*)); } -std::unique_ptr InterpreterExecutable::CreateCostAnalysis() - const { - return MakeUnique(ShapeSizeBytes); -} - } // namespace interpreter } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index c69b0d036d1058a6b24ee609a9923895d3246eec..0e87eb90bff4b896fc4bc0efc4fa7b851631be6f 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -61,8 +61,6 @@ class InterpreterExecutable : public Executable { static int64 ShapeSizeBytes(const Shape& shape); - std::unique_ptr CreateCostAnalysis() const override; - private: TF_DISALLOW_COPY_AND_ASSIGN(InterpreterExecutable); }; diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 0bb3259ef43915067e614e72038387e8300ecc41..511de87b1be10741a4632d82cf726071c5c3fc12 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -100,9 +100,9 @@ bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { return true; } -bool InterpreterExecutor::BlockHostUntilDone(Stream *stream) { +port::Status InterpreterExecutor::BlockHostUntilDoneWithStatus(Stream *stream) { AsExecutorStream(stream)->BlockUntilDone(); - return true; + return port::Status::OK(); } DeviceDescription *InterpreterExecutor::PopulateDeviceDescription() const { diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index c59b2ccb1505b78be0c459ac9311428d65cc7e44..d3753a6a65d64c3d77644367bbd82068d4cf3044 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -157,7 +157,7 @@ class InterpreterExecutor : public internal::StreamExecutorInterface { bool StartTimer(Stream *stream, Timer *timer) override; bool StopTimer(Stream *stream, Timer *timer) override; - bool BlockHostUntilDone(Stream *stream) override; + port::Status BlockHostUntilDoneWithStatus(Stream *stream) override; int PlatformDeviceCount() override { return 1; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 7eda7c2284c2457703fcfcd4226172e41dd4ae01..328afe42bad64713013f761a6819ae8a47a52e04 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1303,7 +1303,7 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - // Copy the root instrucion's result if the it does not match the result + // Copy the root instruction's result if the it does not match the result // layout constraint if (constraints.ResultLayout() != nullptr && !constraints.ResultLayout()->MatchesLayoutInShape( @@ -1328,6 +1328,20 @@ Status LayoutAssignment::RunOnComputation( << ")"; VLOG(2) << " ComputationLayout = " << computation_layout.ToString(); + // Clear existing layouts of the instructions. All layouts must be assigned by + // the LayoutAssignment pass, except for Infeed, Outfeed, Parameters and the + // computation result. The latter two are specified in computation_layout, so + // we only need to keep the existing layouts for Infeed and Outfeed. Clearing + // the layouts here avoids hiding potential bugs in the layout assignment pass + // that may accidently use the existing layout. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kInfeed || + instruction->opcode() == HloOpcode::kOutfeed) { + continue; + } + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + // Construct LayoutConstraints with all layout constraints of the computation. LayoutConstraints constraints(points_to_analysis, computation); diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index c39ff52230055ec322ecf77f8df8ebdea12cdb6c..d51c0d1dfb727801d6d2a8328eba60838373479f 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -131,10 +131,10 @@ TEST_F(LayoutAssignmentTest, FusionInstruction) { std::vector> minor_to_majors = {{0, 1}, {1, 0}}; for (auto& minor_to_major : minor_to_majors) { auto builder = HloComputation::Builder(TestName()); - auto constant_literal1 = test_utils::CreateR2LiteralWithLayout( - {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major); - auto constant_literal2 = test_utils::CreateR2LiteralWithLayout( - {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major); + auto constant_literal1 = Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout(minor_to_major)); + auto constant_literal2 = Literal::CreateR2WithLayout( + {{5.0, 6.0}, {7.0, 8.0}}, LayoutUtil::MakeLayout(minor_to_major)); Shape ashape = constant_literal1->shape(); auto constant1 = builder.AddInstruction( @@ -181,12 +181,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { // Verify the layouts of a tuple are assigned properly (the element layouts // match their source). auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {0, 1}))); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {1, 0}))); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); @@ -218,12 +218,12 @@ TEST_F(LayoutAssignmentTest, TupleLayout) { TEST_F(LayoutAssignmentTest, TupleSelect) { // Verify layouts of a select with tuple operands is assigned properly. auto builder = HloComputation::Builder(TestName()); - auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {0, 1}))); - auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - test_utils::CreateR2LiteralWithLayout({{1.0, 2.0}, {3.0, 4.0}}, - {1, 0}))); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({1, 0})))); auto tuple0 = builder.AddInstruction( HloInstruction::CreateTuple({constant0, constant1})); auto tuple1 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index c27a8956a706febd1855854a2d0560754caf5c03..68c99256a246edcf43a8358f667fc4458b9b4fea 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -103,7 +103,7 @@ namespace { // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. // Each use in 'uses' is a pair (HloInstruction* user, int64 operand_index) -// where 'user' is a user of an alias of 'intruction' at 'index', and +// where 'user' is a user of an alias of 'instruction' at 'index', and // 'operand_index' is the operand index at which the alias appears in the // operand list of 'user'. std::vector> GetAllUsesOfInstructionAtIndex( @@ -215,7 +215,8 @@ bool CanShareOperandBufferWithUser( auto add_operand_it = std::find_if(add->operands().begin(), add->operands().end(), [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kDot || + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot || (operand->opcode() == HloOpcode::kFusion && operand->fusion_kind() == HloInstruction::FusionKind::kTransposeDot); @@ -242,6 +243,31 @@ bool CanShareOperandBufferWithUser( std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kCall) { + // TODO(b/62548313): Remove when buffer assignment is module scoped and + // does not assign buffers to calls. + // Find called computation parameter associated with 'operand'. + const std::vector operand_indices = user->OperandIndices(operand); + if (operand_indices.size() > 1) { + return false; + } + CHECK_EQ(1, operand_indices.size()); + auto* param = user->to_apply()->parameter_instruction(operand_indices[0]); + // Get all uses of 'operand' at 'index' in called computation. + auto param_uses = GetAllUsesOfInstructionAtIndex(param, operand_index, + points_to_analysis); + + // Return true iff: + // *) There exists exactly one use of 'operand' in called computation. + // *) The unique use is by the root instruction of called computation. + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + auto* callee_root = user->to_apply()->root_instruction(); + return param_uses.size() == 1 && param_uses[0].first == callee_root && + callee_root->IsElementwiseOnOperand(param_uses[0].second); + } // Check if 'user' is element-wise. return user->IsElementwise(); } @@ -294,7 +320,8 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand, auto add_operand_it = std::find_if(add->operands().begin(), add->operands().end(), [&](HloInstruction* operand) { - return operand->opcode() == HloOpcode::kDot || + return operand->opcode() == HloOpcode::kConvolution || + operand->opcode() == HloOpcode::kDot || (operand->opcode() == HloOpcode::kFusion && operand->fusion_kind() == HloInstruction::FusionKind::kTransposeDot); @@ -320,6 +347,31 @@ bool CanShareOperandBufferWithUser(HloInstruction* operand, std::vector operand_indices = user->OperandIndices(operand); return operand_indices.size() == 1 && operand_indices[0] == 0; } + if (user->opcode() == HloOpcode::kCall) { + // Get all uses of value defined by 'operand' at 'operand_index'. + const auto& uses = + dataflow.GetValueDefinedAt(operand, operand_index).uses(); + // Return true iff: + // *) There exists two uses of 'operand'. + // *) One use is by 'user' (caller). + // *) One use is by root instruction of called computation (callee root). + // (Note: we check the root of the called computation, because the + // root result buffer is required to alias with the Call result buffer). + // *) The root instruction of the called computation is element-wise on + // 'operand'. + const bool found_caller_use = + std::find_if(uses.begin(), uses.end(), [user](const HloUse& use) { + return use.instruction == user; + }) != uses.end(); + auto* callee_root = user->to_apply()->root_instruction(); + const bool found_elementwise_callee_use = + std::find_if( + uses.begin(), uses.end(), [callee_root](const HloUse& use) { + return use.instruction == callee_root && + callee_root->IsElementwiseOnOperand(use.operand_number); + }) != uses.end(); + return uses.size() == 2 && found_caller_use && found_elementwise_callee_use; + } // Check if 'user' is element-wise. return user->IsElementwise(); } diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index b5e15906d3c085f773eb46b543515a614e63c59a..2c2a02f6375343d67dfb155bbb03729ff6e490d2 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -277,8 +277,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto b = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -312,8 +315,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -415,5 +421,44 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); } +// Tests that Call can alias operand buffer if the only use of the operand +// in the called computation is an elementwise instruction. +TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + // Build sub-computation with fusion root. + auto sub_builder = HloComputation::Builder(TestName() + "_sub"); + auto sub_param = sub_builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "sub_param")); + auto one = sub_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto ones = sub_builder.AddInstruction( + HloInstruction::CreateBroadcast(shape, one, {1})); + auto add = sub_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones)); + + module_ = CreateNewModule(); + auto sub_computation = module_->AddEmbeddedComputation(sub_builder.Build()); + sub_computation->CreateFusionInstruction({add, ones}, + HloInstruction::FusionKind::kLoop); + + // Build entry-computation with kCall which calls 'sub_computation'. + auto builder = HloComputation::Builder(TestName()); + + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + auto reverse = + builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0})); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(shape, {reverse}, sub_computation)); + computation_ = module_->AddEntryComputation(builder.Build()); + + RunAnalysis(); + + EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, + *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(reverse, {}, call, {}, + *dataflow_analysis_)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..34f3419269abbc73cd0ddb13c723a8da38ab19ff --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_compiler.h" + +namespace xla { +StatusOr>> LLVMCompiler::Compile( + std::vector> modules, + std::vector> + stream_execs) { + std::vector> result; + for (size_t i = 0; i < modules.size(); i++) { + if (stream_execs[i].size() != 1) { + return Unimplemented( + "Model partitioning not implemented for the CPU/GPU compilers!"); + } + + TF_ASSIGN_OR_RETURN( + modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0])); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + RunBackend(std::move(modules[i]), stream_execs[i][0])); + result.push_back(std::move(executable)); + } + + return {std::move(result)}; +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index b2e72871c10192c84349b117797c7bd7e6ee251a..c5393cef4f961c5d04c32d0d4291732b8ec702f1 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -57,6 +57,21 @@ class LLVMCompiler : public Compiler { void RemovePostOptimizationHook() { user_post_optimization_hook_ = nullptr; } + // Bring in + // StatusOr> RunBackend( + // std::unique_ptr module, + // perftools::gputools::StreamExecutor* stream_exec) + // StatusOr> RunHloPasses( + // std::unique_ptr module, + // perftools::gputools::StreamExecutor* stream_exec) + using Compiler::RunBackend; + using Compiler::RunHloPasses; + + StatusOr>> Compile( + std::vector> modules, + std::vector> + stream_execs) override; + protected: ModuleHook user_pre_optimization_hook_; ModuleHook user_post_optimization_hook_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 075d4a1ab5e5f39394ade393d21525ca3e97136e..d878061f724de1c82f8285b0f082d0be4d5778df 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -48,6 +48,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -155,6 +156,30 @@ cc_library( ], ) +cc_library( + name = "vector_support_library", + srcs = ["vector_support_library.cc"], + hdrs = ["vector_support_library.h"], + deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "@llvm//:core", + ], +) + +cc_library( + name = "kernel_support_library", + srcs = ["kernel_support_library.cc"], + hdrs = ["kernel_support_library.h"], + deps = [ + ":llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index bdddc232ef74dfa37e2d5cc780b0fe11e7bc8e76..21bca1d6beff5b2804531724b94b123d4523c173 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -83,7 +83,7 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, if (std::find(parameter_instructions.begin(), parameter_instructions.end(), &hlo) != parameter_instructions.end()) { - array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{})); + array->MarkInvariantOverWholeProgram(context_); } } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e3f98ac13e76f0df465066422ca7918a0f218b60..7224bd689842d89563b374f3db3d4e314be18764 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -256,10 +256,10 @@ void IrArray::AnnotateLoadStoreInstructionWithMetadata( llvm::Instruction* instruction) const { CHECK(llvm::isa(instruction) || llvm::isa(instruction)); + CHECK(!llvm::isa(instruction) || !is_invariant_) + << "Trying to create a store to an invariant IRArray."; for (const auto& kind_md_pair : metadata_) { - CHECK(kind_md_pair.first != llvm::LLVMContext::MD_invariant_load || - llvm::isa(instruction)); instruction->setMetadata(kind_md_pair.first, kind_md_pair.second); } } diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h index 1ed7e99a829f5b0daa709913554d2300503ca33e..387d4629125cbb791840e943013188d14159908a 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h @@ -229,9 +229,33 @@ class IrArray { AddMetadata(llvm::LLVMContext::MD_noalias, noalias); } - void AddInvariantLoad(llvm::MDNode* invariant_load) { - CHECK_NE(invariant_load, nullptr); - AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load); + // Promises LLVM that the data pointed to by this IrArray never changes after + // it's first loaded. + // + // The temporal scope of this promise is the "whole program" from LLVM's point + // of view, but how this translates to HLOs differs between backends. + // + // In the single-threaded CPU backend, we emit one function that + // runs all the HLOs in sequence, so the whole program is the whole HLO + // module. + // + // In the GPU backend, we emit one GPU kernel per top-level HLO (i.e. per HLO + // in the entry computation). From LLVM's perspective, launching a new kernel + // is like launching a new program, and so the whole program is one top-level + // HLO. Since the scope of the promise is smaller than in the CPU backend, we + // can mark more things as invariant in the GPU backend. + // + // Marking loads as invariant is particularly helpful on GPUs because + // invariant loads can be lowered to PTX ld.global.nc (equivalent to CUDA's + // __ldg intrinsic). These loads use a special cache, and can be + // significantly faster than regular loads. + void MarkInvariantOverWholeProgram(llvm::LLVMContext* context) { + if (is_invariant_) { + return; + } + is_invariant_ = true; + AddMetadata(llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(*context, {})); } const std::map& metadata() const { return metadata_; } @@ -261,6 +285,8 @@ class IrArray { // loads/stores for this array. They keys are the metadata kinds and the // values are the metadata nodes. std::map metadata_; + + bool is_invariant_ = false; }; } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..d68d699d7ef420bb644829125e46b5f565c93825 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +void KernelSupportLibrary::For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + If(ir_builder_->CreateICmpSLT(start, end), [&]() { + for_body_generator(start, /*is_first_iteration=*/true); + For(name, ir_builder_->CreateAdd(start, step), end, step, + [&](llvm::Value* iv) { for_body_generator(iv, false); }); + }); +} + +void KernelSupportLibrary::For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& for_body_generator) { + if (peel_first_iteration) { + For(name, start, end, step, true, + [&](llvm::Value* indvar, bool is_first_iteration) { + for_body_generator(indvar, ir_builder_->getInt1(is_first_iteration)); + }); + } else { + std::unique_ptr loop = llvm_ir::ForLoop::EmitForLoop( + name, start, end, step, ir_builder_, + /*prevent_unrolling=*/prevent_unrolling_, + /*prevent_vectorization=*/prevent_vectorization_); + ir_builder_->SetInsertPoint(&loop->GetBodyBasicBlock()->back()); + for_body_generator(loop->GetIndVarValue(), + /*is_first_iteration=*/ir_builder_->CreateICmpEQ( + loop->GetIndVarValue(), start)); + llvm_ir::SetToLastInsertPoint(loop->GetExitBasicBlock(), ir_builder_); + } +} + +void KernelSupportLibrary::If( + llvm::Value* condition, const std::function& true_block_generator, + const std::function& false_block_generator) { + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(condition, "", ir_builder_); + ir_builder_->SetInsertPoint(&if_data.true_block->back()); + true_block_generator(); + ir_builder_->SetInsertPoint(&if_data.false_block->back()); + false_block_generator(); + llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); +} + +void KernelSupportLibrary::EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + KernelSupportLibrary::ArgumentVector arguments, + const std::function& + kernel_body_generator) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + llvm::Function* function = + module->getFunction(llvm_ir::AsStringRef(kernel_name)); + if (!function) { + VLOG(2) << "Generating kernel for " << kernel_name; + std::vector arg_types; + std::transform(arguments.begin(), arguments.end(), + std::back_inserter(arg_types), + [](llvm::Value* arg) { return arg->getType(); }); + + auto* function_type = llvm::FunctionType::get( + ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false); + + function = llvm_ir::CreateFunction( + function_type, llvm::GlobalValue::InternalLinkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, kernel_name, module); + + llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder); + + auto* entry_bb = + llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function); + auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(), + /*retVal=*/nullptr, entry_bb); + // Set the insert point to before return_inst. + ir_builder->SetInsertPoint(return_inst); + + std::vector arg_values; + std::transform(function->arg_begin(), function->arg_end(), + std::back_inserter(arg_values), std::addressof); + kernel_body_generator(arg_values); + } else { + VLOG(3) << "Re-using kernel for " << kernel_name; + } + + ir_builder->CreateCall(function, llvm_ir::AsArrayRef(arguments)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h new file mode 100644 index 0000000000000000000000000000000000000000..150a464c66961a0e68149bb4729d60cc4e363ba3 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -0,0 +1,163 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ + +#include + +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace xla { +// A thin wrapper around llvm_loop.h to make code generating structured control +// flow more readable. +class KernelSupportLibrary { + public: + // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR. + // If `prevent_unrolling` is true then unrolling is explicitly disabled on + // every loop generated by this instance of KernelSupportLibrary. + explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder, + bool prevent_unrolling = true, + bool prevent_vectorization = true) + : ir_builder_(ir_builder), + prevent_unrolling_(prevent_unrolling), + prevent_vectorization_(prevent_vectorization) {} + + // Generates the following control flow structure: + // + // if (`start` < `end`) { + // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`; + // for (i64 i = `start` + `step`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/false)`; + // } + void For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& + for_body_generator); + + void For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& + for_body_generator) { + For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + // Generates the following control flow structure if `peel_first_iteration` is + // true: + // + // if (`start` < `end`) { + // `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/,true)`; + // for (i64 i = `start` + `step`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, /*is_first_iteration=*/,false)`; + // } + // + // and the following if `peel_first_iteration` is false: + // + // for (i64 i = `start`; i s< `end`; i += `step`) + // `for_body_generator(/*ind_var=*/,i, + // /*is_first_iteration=*/,(i != `start`))`; + void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, bool peel_first_iteration, + const std::function& + for_body_generator); + + void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + int64 step, bool peel_first_iteration, + const std::function& + for_body_generator) { + For(name, /*start=*/start, /*end=*/end, + /*step=*/ir_builder_->getInt64(step), peel_first_iteration, + for_body_generator); + } + + void For( + tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end, + llvm::Value* step, + const std::function& for_body_generator) { + For(name, start, end, step, + /*peel_first_iteration=*/false, + [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); }); + } + + void For( + tensorflow::StringPiece name, int64 start, int64 end, int64 step, + const std::function& for_body_generator) { + For(name, /*start=*/ir_builder_->getInt64(start), + /*end=*/ir_builder_->getInt64(end), + /*step=*/ir_builder_->getInt64(step), for_body_generator); + } + + // Generates the following control flow structure: + // + // if (`condition`) + // `true_block_generator()`; + // else + // `false_block_generator()`; + void If(llvm::Value* condition, + const std::function& true_block_generator, + const std::function& false_block_generator = []() {}); + + using ArgumentVector = tensorflow::gtl::ArraySlice; + + // Generates the following control flow structure: + // + // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { + // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`}); + // } + // + // ... + // call @`kernel_name`(arguments[0], arguments[1] ...) + // ... + // + // If a function called `kernel_name` is already present in the module then + // that function is re-used. In that sense we're using the llvm::Module as a + // cache of outlined kernels, keyed by function name. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + ArgumentVector arguments, + const std::function& kernel_body_generator); + + // Thin wrapper around the more general EmitAndCallOutlinedKernel above. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + const std::function& + kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2]); + }); + } + + private: + llvm::IRBuilder<>* ir_builder_; + bool prevent_unrolling_; + bool prevent_vectorization_; +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc index 83d35cb9efca0c27765045ce214e0e1060b18ed0..7b227ce294176cfbbf7308bbf65afe21814f3dea 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc @@ -34,21 +34,24 @@ namespace llvm_ir { ForLoop::ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - llvm::Value* step, bool prevent_unrolling) + llvm::Value* step, bool prevent_unrolling, + bool prevent_vectorization) : prefix_(prefix.ToString()), suffix_(suffix.ToString()), start_index_(start_index), end_index_(end_index), step_(step), insert_before_bb_(nullptr), - prevent_unrolling_(prevent_unrolling) {} + prevent_unrolling_(prevent_unrolling), + prevent_vectorization_(prevent_vectorization) {} /* static */ std::unique_ptr ForLoop::EmitForLoop( tensorflow::StringPiece prefix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling) { - std::unique_ptr loop(new ForLoop( - prefix, /*suffix=*/"", start_index, end_index, step, prevent_unrolling)); + bool prevent_unrolling, bool prevent_vectorization) { + std::unique_ptr loop(new ForLoop(prefix, /*suffix=*/"", start_index, + end_index, step, prevent_unrolling, + prevent_vectorization)); loop->Emit(ir_builder); return loop; } @@ -127,14 +130,12 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->CreateStore(indvar_inc, indvar_address); llvm::BranchInst* back_branch = ir_builder->CreateBr(header_bb_); - if (prevent_unrolling_) { - const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; - llvm::LLVMContext* ctx = &back_branch->getContext(); - + std::vector loop_metadata = GetLoopMetadata(ir_builder); + if (!loop_metadata.empty()) { + llvm::LLVMContext* ctx = &start_index_->getContext(); auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None); - auto no_unroll_node = llvm::MDNode::get( - *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}); - auto loop_id = llvm::MDNode::get(*ctx, {temp_node.get(), no_unroll_node}); + loop_metadata.insert(loop_metadata.begin(), temp_node.get()); + auto loop_id = llvm::MDNode::get(*ctx, loop_metadata); loop_id->replaceOperandWith(0, loop_id); back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id); } @@ -143,6 +144,27 @@ void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) { ir_builder->SetInsertPoint(exit_bb_); } +std::vector ForLoop::GetLoopMetadata( + llvm::IRBuilder<>* ir_builder) { + const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable"; + const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable"; + llvm::LLVMContext* ctx = &start_index_->getContext(); + + std::vector result; + if (prevent_unrolling_) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)})); + } + + if (prevent_vectorization_) { + result.push_back(llvm::MDNode::get( + *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName), + llvm::ConstantAsMetadata::get(ir_builder->getFalse())})); + } + + return result; +} + string ForLoop::GetQualifiedName(tensorflow::StringPiece name) { return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_)); } @@ -156,23 +178,25 @@ llvm::BasicBlock* ForLoop::CreateLoopBB(tensorflow::StringPiece name, std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { return AddLoop(suffix, start_index, end_index, ir_builder_->getInt64(1), - prevent_unrolling); + prevent_unrolling, prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { if (inner_loop_body_bb_ != nullptr) { // Create this loop inside the previous one. ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt()); } std::unique_ptr loop(new ForLoop( /*prefix=*/name_, suffix, start_index, end_index, stride, - prevent_unrolling)); + prevent_unrolling, prevent_vectorization)); loop->Emit(ir_builder_); if (outer_loop_preheader_bb_ == nullptr) { @@ -191,20 +215,24 @@ std::unique_ptr ForLoopNest::AddLoop(tensorflow::StringPiece suffix, std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), - ir_builder_->getInt64(end_index), prevent_unrolling); + ir_builder_->getInt64(end_index), prevent_unrolling, + prevent_vectorization); } std::unique_ptr ForLoopNest::AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling) { + bool prevent_unrolling, + bool prevent_vectorization) { CHECK_LE(start_index, end_index); return AddLoop(suffix, ir_builder_->getInt64(start_index), ir_builder_->getInt64(end_index), - ir_builder_->getInt64(stride), prevent_unrolling); + ir_builder_->getInt64(stride), prevent_unrolling, + prevent_vectorization); } IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape, diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h index 90f7c7df9e22d6404e9fdad2ce210506583bd427..20069ce5a28184a5a9216d1a3751d1cee547727d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h @@ -71,12 +71,10 @@ class ForLoop { // // If `prevent_unrolling` is true then emit metadata that directs LLVM to not // unroll the generated loop. - static std::unique_ptr EmitForLoop(tensorflow::StringPiece prefix, - llvm::Value* start_index, - llvm::Value* end_index, - llvm::Value* step, - llvm::IRBuilder<>* ir_builder, - bool prevent_unrolling = false); + static std::unique_ptr EmitForLoop( + tensorflow::StringPiece prefix, llvm::Value* start_index, + llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder, + bool prevent_unrolling = false, bool prevent_vectorization = false); // The names of the blocks follow LLVM's conventions. Control flow amongst the // blocks for the example C code looks like: @@ -130,7 +128,7 @@ class ForLoop { ForLoop(tensorflow::StringPiece prefix, tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* step, - bool prevent_unrolling); + bool prevent_unrolling, bool prevent_vectorization); // Emit the loop at the insert point of the builder. void Emit(llvm::IRBuilder<>* ir_builder); @@ -142,6 +140,10 @@ class ForLoop { // they are set. string GetQualifiedName(tensorflow::StringPiece name); + // Return a list of metadata nodes that should be associated with the + // llvm::Loop for this `ForLoop`. + std::vector GetLoopMetadata(llvm::IRBuilder<>* ir_builder); + string prefix_; string suffix_; llvm::Value* start_index_; @@ -160,6 +162,7 @@ class ForLoop { llvm::BasicBlock* exit_bb_; llvm::Value* indvar_; bool prevent_unrolling_; + bool prevent_vectorization_; TF_DISALLOW_COPY_AND_ASSIGN(ForLoop); }; @@ -185,24 +188,28 @@ class ForLoopNest { std::unique_ptr AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, llvm::Value* stride, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop(tensorflow::StringPiece suffix, llvm::Value* start_index, llvm::Value* end_index, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // A convenient wrapper of the other flavor of AddLoop. The given start and // end index are constant. std::unique_ptr AddLoop(int64 start_index, int64 end_index, int64 stride, tensorflow::StringPiece suffix, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Like the above, except that it defaults to a stride of one. std::unique_ptr AddLoop(int64 start_index, int64 end_index, tensorflow::StringPiece suffix, - bool prevent_unrolling = false); + bool prevent_unrolling = false, + bool prevent_vectorization = false); // Add loops to iterate through the indices within the specified // shape. The returned index collects the induction variables of the diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 5dff4b5778970dd473c5f158b3828a850847d1ff..9a0c94b1c73c48682c1e868d4518b3797b01bbed 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/Target/TargetOptions.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -141,6 +142,13 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: + case BF16: + // For BF16 we just need some type that is 16 bits wide so that it will + // take up the right amount of space in memory. LLVM does not have a BF16 + // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so + // we can't map it directly to an LLVM type. We will not map a BF16 + // addition to an addition on this type (int16) - this is just the type + // used for storage. return llvm::Type::getInt16Ty(module->getContext()); case S32: case U32: @@ -163,8 +171,9 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, // z, and reinterpret_cast(z)[1] shall designate the // imaginary part of z. return llvm::StructType::create( - "complex64", llvm::Type::getFloatTy(module->getContext()), - llvm::Type::getFloatTy(module->getContext())); + {llvm::Type::getFloatTy(module->getContext()), + llvm::Type::getFloatTy(module->getContext())}, + "complex64", /*isPacked=*/true); } return cplx_t; } @@ -178,6 +187,21 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, } } +int GetSizeInBits(llvm::Type* type) { + const llvm::StructType* struct_ty = llvm::dyn_cast(type); + if (struct_ty) { + CHECK(struct_ty->isPacked()); + int bits = 0; + for (auto element_type : struct_ty->elements()) { + bits += GetSizeInBits(element_type); + } + return bits; + } + int bits = type->getPrimitiveSizeInBits(); + CHECK_GT(bits, 0) << "type is not sized"; + return bits; +} + llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); if (ShapeUtil::IsTuple(shape)) { @@ -263,6 +287,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); break; + case BF16: + value = llvm::ConstantInt::get( + ir_element_type, + tensorflow::bit_cast(literal.Get(*multi_index))); + break; case F64: value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); @@ -537,6 +566,14 @@ void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { builder->SetInsertPoint(blk, blk->getFirstInsertionPt()); } +void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) { + if (llvm::Instruction* terminator = blk->getTerminator()) { + builder->SetInsertPoint(terminator); + } else { + builder->SetInsertPoint(blk); + } +} + llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, llvm::IRBuilder<>* builder) { auto size = rotand->getType()->getPrimitiveSizeInBits(); @@ -555,8 +592,9 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) { llvm::FastMathFlags flags; if (fast_math_enabled) { - // UnsafeAlgebra implies NoInfs, NoNaNs, NoSignedZeros, and AllowReciprocal. - flags.setUnsafeAlgebra(); + // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, + // AllowReciprocal, AllowContract, and ApproxFunc. + flags.setFast(); } return flags; } @@ -619,14 +657,27 @@ std::map MergeMetadata( return result; } +static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static NameUniquer* uniquer = new NameUniquer(/*separator=*/"-"); + + tensorflow::mutex_lock lock(mu); + return uniquer->GetUniqueName(prefix); +} + Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized) { - string safe_file_name_base = SanitizeFileName(hlo_module_name); + // We can end up compiling different modules with the same name when using + // XlaJitCompiledCpuFunction::Compile. Avoid overwriting IR files previously + // dumped from the same process in such cases. + string unique_and_safe_file_name = GetProcessUniqueIrFileName( + tensorflow::strings::StrCat("ir-", SanitizeFileName(hlo_module_name), "-", + optimized ? "with" : "no", "-opt")); + string ir_file_name = tensorflow::io::JoinPath( directory_name, - tensorflow::strings::StrCat("ir-", safe_file_name_base, "-", - optimized ? "with" : "no", "-opt.ll")); + tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); std::unique_ptr f; TF_RETURN_IF_ERROR( @@ -637,5 +688,32 @@ Status DumpIRToDirectory(const string& directory_name, return f->Close(); } +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module) { + llvm::Function* function = + llvm::Function::Create(function_type, linkage, AsStringRef(name), module); + function->setCallingConv(llvm::CallingConv::C); + function->addFnAttr("no-frame-pointer-elim", "false"); + + if (enable_fast_math) { + function->addFnAttr("unsafe-fp-math", "true"); + function->addFnAttr("no-infs-fp-math", "true"); + function->addFnAttr("no-nans-fp-math", "true"); + function->addFnAttr("no-signed-zeros-fp-math", "true"); + } + + // Add the optize attribute to the function if optimizing for size. This + // controls internal behavior of some optimization passes (e.g. loop + // unrolling). + if (optimize_for_size) { + function->addFnAttr(llvm::Attribute::OptimizeForSize); + } + + return function; +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 304192b58e9331c2544f973bf65299111122aea8..6bdc6a01a2b487df3dd80a02e67f5bcf62dead31 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -129,6 +129,9 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, llvm::Module* module); +// Returns the type size in bits. If "type" is a struct, it must be packed. +int GetSizeInBits(llvm::Type* type); + // Returns the LLVM type which represents the given XLA shape. For example, // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]]. llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); @@ -243,6 +246,8 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); +void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder); + // Create a bitwise rotation of `rotand` by `rotor`. llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor, llvm::IRBuilder<>* builder); @@ -276,6 +281,12 @@ Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized); +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module); + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h index 11e84d9cb5defbcb87a8f696d56c139686c960d8..f72f482e3128c61e53cc454e7da8b5795ba6f695 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -40,11 +40,24 @@ bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, inline bool CanEmitFusedDynamicUpdateSliceInPlace( HloInstruction* fusion, const BufferAssignment& assignment) { CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); - return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop && - fusion->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice && - CanUpdateDynamicSliceInPlace(fusion->fused_expression_root(), - assignment); + HloInstruction* fused_root = fusion->fused_expression_root(); + if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice || + fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { + return false; + } + // Walk DynamicUpdateSlice operand(0) to fused parameter and get its + // associated operand. See if it shares an allocation with this operand. + HloInstruction* fusion_operand; + ShapeIndex index; + std::tie(fusion_operand, index) = + fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex(); + if (fusion_operand->opcode() != HloOpcode::kParameter) { + return false; + } + auto* operand = fusion->operand(fusion_operand->parameter_number()); + return assignment.HasAllocationAt(operand, index) && + assignment.HasAllocationAt(fusion, {}) && + assignment.SharesSliceAtIndex(fusion, {}, operand, index); } // Emits IR for running the given dynamic-update-slice op in-place -- that is, diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..59e82960787918d4747ad4dedf4bfb4f2fd40352 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc @@ -0,0 +1,268 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, + int64 vector_size, + llvm::IRBuilder<>* ir_builder, + std::string name) + : vector_size_(vector_size), + primitive_type_(primitive_type), + ir_builder_(ir_builder), + name_(std::move(name)) { + scalar_type_ = llvm_ir::PrimitiveTypeToIrType( + primitive_type, ir_builder_->GetInsertBlock()->getModule()); + scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); + vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); + vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); +} + +llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { + CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + return MulInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, + llvm::Value* rhs) { + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFMul(lhs, rhs, name()); + } else { + return ir_builder()->CreateMul(lhs, rhs, name()); + } +} + +llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { + CHECK(lhs->getType() == scalar_type() || lhs->getType() == vector_type()); + return AddInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, + llvm::Value* rhs) { + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFAdd(lhs, rhs, name()); + } else { + return ir_builder()->CreateAdd(lhs, rhs, name()); + } +} + +llvm::Value* VectorSupportLibrary::ComputeOffsetPointer( + llvm::Value* base_pointer, llvm::Value* offset_elements) { + if (base_pointer->getType() != scalar_pointer_type()) { + base_pointer = ir_builder()->CreateBitCast(base_pointer, + scalar_pointer_type(), name()); + } + return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements}, + name()); +} + +llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) { + if (pointer->getType() != vector_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name()); + } + return ir_builder()->CreateAlignedLoad( + pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); +} + +llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + return ir_builder()->CreateAlignedLoad( + pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); +} + +void VectorSupportLibrary::StoreVector(llvm::Value* value, + llvm::Value* pointer) { + if (pointer->getType() != vector_pointer_type()) { + pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); + } + ir_builder()->CreateAlignedStore( + value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); +} + +void VectorSupportLibrary::StoreScalar(llvm::Value* value, + llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + ir_builder()->CreateAlignedStore( + value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); +} + +llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + return ir_builder()->CreateVectorSplat( + vector_size(), ir_builder()->CreateLoad(pointer), name()); +} + +llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { + llvm::SmallVector mask(vector_size(), nullptr); + for (unsigned i = vector_size(); i != 1; i >>= 1) { + // On every iteration, we shuffle half of the remaining lanes to the top + // half of shuffle, and add two old and the new vector. + + for (unsigned j = 0; j < vector_size(); ++j) { + if (j < (i / 2)) { + mask[j] = ir_builder()->getInt32(i / 2 + j); + } else { + mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty()); + } + } + + llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector( + vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask), ""); + vector = Add(vector, half_remaining_lanes); + } + + return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0), + name()); +} + +llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, + llvm::Value* rhs) { + CHECK_EQ(lhs->getType(), vector_type()); + CHECK_EQ(rhs->getType(), vector_type()); + CHECK_EQ(vector_size() % 2, 0); + + llvm::SmallVector mask_a, mask_b; + + // Adding the values shuffled using mask_a and mask_b gives us the + // AVX-style horizontal add we want. The masks work as documented + // in https://llvm.org/docs/LangRef.html#shufflevector-instruction + // + // Here are the masks for vector_width() == 8: + // + // index: |0 |1 |2 | 3 |4 |5 | 6 | 7 + // --------+--+--+--+---+--+--+---+--- + // mask_a: |0 |2 |8 |10 |4 |6 |12 |14 + // mask_b: |1 |3 |9 |11 |5 |7 |13 |16 + // + // So, as an example, the value at lane 3 of the result vector is + // the result of adding lane 10 and lane 11 in the combined lhs++rhs + // vector, which are the lanes 2 and 3 in the rhs vector. + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size(); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + + llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_a)); + llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_b)); + + return Add(shuffle_0, shuffle_1); +} + +llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i + vector_size() / 2)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +std::vector VectorSupportLibrary::ComputeHorizontalSums( + std::vector vectors) { + // TODO(sanjoy): Move this magic constant to TargetMachineFeatures. + const int kAvxVectorWidth = 8; + if (vector_size() == kAvxVectorWidth && vectors.size() == kAvxVectorWidth) { + return ComputeAvxOptimizedHorizontalSums(std::move(vectors)); + } + + std::vector result; + std::transform(vectors.begin(), vectors.end(), std::back_inserter(result), + [this](llvm::Value* vector) { return AddReduce(vector); }); + return result; +} + +std::vector +VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( + std::vector vectors) { + while (vectors.size() != 2) { + std::vector new_vectors; + for (int i = 0; i < vectors.size(); i += 2) { + new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1])); + } + + vectors = std::move(new_vectors); + } + + llvm::Value* low = + AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0])); + llvm::Value* high = + AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1])); + + std::vector results; + for (int i = 0; i < 8; i++) { + llvm::Value* scalar_result = ir_builder()->CreateExtractElement( + i < 4 ? low : high, ir_builder()->getInt32(i % 4), name()); + results.push_back(scalar_result); + } + + return results; +} + +llvm::Value* VectorSupportLibrary::GetZeroVector() { + return llvm::Constant::getNullValue(vector_type()); +} + +llvm::Value* VectorSupportLibrary::GetZeroScalar() { + return llvm::Constant::getNullValue(scalar_type()); +} + +LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) + : ir_builder_(ir_builder) { + alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); +} + +llvm::Value* LlvmVariable::Get() const { + return ir_builder_->CreateLoad(alloca_); +} + +void LlvmVariable::Set(llvm::Value* new_value) { + ir_builder_->CreateStore(new_value, alloca_); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h new file mode 100644 index 0000000000000000000000000000000000000000..f4c7a6a420a55db5760e67cf3725dc9cfe9e8b52 --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h @@ -0,0 +1,205 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ + +#include + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +// A thin wrapper around llvm_util.h to make code generating vector math flow +// more readable. +class VectorSupportLibrary { + public: + // This VectorSupportLibrary instance remembers `primitive_type` and + // `vector_size`, and these are implicitly used by the methods on this + // instance (i.e. LoadVector will load a vector of type <`vector_size` x + // `primitive_type`>). + VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, + llvm::IRBuilder<>* ir_builder, std::string name); + + llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { + return Mul(ir_builder()->getInt64(lhs), rhs); + } + + llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Add(int64 lhs, llvm::Value* rhs) { + return Add(ir_builder()->getInt64(lhs), rhs); + } + + llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) { + return Add(c, Mul(a, b)); + } + + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + int64 offset_elements) { + return ComputeOffsetPointer(base_pointer, + ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadVector(llvm::Value* pointer); + + llvm::Value* LoadVector(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements)); + } + + llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) { + return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadScalar(llvm::Value* pointer); + + llvm::Value* LoadScalar(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements)); + } + + llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) { + return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + void StoreVector(llvm::Value* value, llvm::Value* pointer); + + void StoreVector(llvm::Value* value, llvm::Value* base_pointer, + llvm::Value* offset_elements) { + StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements)); + } + + void StoreVector(llvm::Value* value, llvm::Value* base_pointer, + int64 offset_elements) { + StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements)); + } + + void StoreScalar(llvm::Value* value, llvm::Value* pointer); + void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, + llvm::Value* offset_elements) { + StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements)); + } + + void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, + int64 offset_elements) { + StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadBroadcast(llvm::Value* pointer); + llvm::Value* LoadBroadcast(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements)); + } + llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) { + return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + // Compute the horizontal sum of each vector in `vectors`. The i'th element + // in the result vector is the (scalar) horizontal sum of the i'th vector in + // `vectors`. + std::vector ComputeHorizontalSums( + std::vector vectors); + + llvm::Value* GetZeroVector(); + llvm::Value* GetZeroScalar(); + + llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + int64 vector_size() const { return vector_size_; } + llvm::Type* vector_type() const { return vector_type_; } + llvm::Type* vector_pointer_type() const { return vector_pointer_type_; } + llvm::Type* scalar_type() const { return scalar_type_; } + llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; } + + const std::string& name() const { return name_; } + + private: + llvm::Value* ExtractLowHalf(llvm::Value*); + llvm::Value* ExtractHighHalf(llvm::Value*); + + llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs); + + llvm::Value* AddReduce(llvm::Value* vector); + + // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The + // resulting IR for an 8-float wide vector is expected to lower to a single + // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in + // other cases. + // + // For a vector width of 8, the result vector is computed as: + // Result[0] = Lhs[0] + Lhs[1] + // Result[1] = Lhs[2] + Lhs[3] + // Result[2] = Rhs[0] + Rhs[1] + // Result[3] = Rhs[2] + Rhs[3] + // Result[4] = Lhs[4] + Lhs[5] + // Result[5] = Lhs[6] + Lhs[7] + // Result[6] = Rhs[4] + Rhs[5] + // Result[7] = Rhs[6] + Rhs[7] + llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs); + + std::vector ComputeAvxOptimizedHorizontalSums( + std::vector vectors); + + int64 vector_size_; + PrimitiveType primitive_type_; + llvm::IRBuilder<>* ir_builder_; + llvm::Type* vector_type_; + llvm::Type* vector_pointer_type_; + llvm::Type* scalar_type_; + llvm::Type* scalar_pointer_type_; + std::string name_; +}; + +// This wraps an alloca-backed stack variable which LLVM's SSA construction pass +// can later convert to a SSA value. +class LlvmVariable { + public: + LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); + + llvm::Value* Get() const; + void Set(llvm::Value* new_value); + + private: + llvm::AllocaInst* alloca_; + llvm::IRBuilder<>* ir_builder_; +}; + +class VectorVariable : public LlvmVariable { + public: + VectorVariable(VectorSupportLibrary* vector_support, + llvm::Value* initial_value) + : LlvmVariable(vector_support->vector_type(), + vector_support->ir_builder()) { + Set(initial_value); + } +}; + +class ScalarVariable : public LlvmVariable { + public: + ScalarVariable(VectorSupportLibrary* vector_support, + llvm::Value* initial_value) + : LlvmVariable(vector_support->scalar_type(), + vector_support->ir_builder()) { + Set(initial_value); + } +}; +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index d4d35da9d636e6e204f36850e7987327ab258696..06f43bd3cb2376d34a3104133c868c4f4e5cc730 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -68,26 +68,6 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr execute_backend) : Service(options, std::move(execute_backend)) {} -namespace { -// Returns the space required to allocate a shape. If -// allocate_space_for_deep_copy the space includes all sub-buffers of -// a tuple. -int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, - TransferManager* transfer_manager) { - int64 size = 0; - // TODO(b/33492279) remove once no devices represent result tuples as - // contiguous buffers. - if (allocate_space_for_deep_copy) { - ShapeUtil::ForEachSubshape( - shape, [&size, transfer_manager](const Shape& subshape, - const ShapeIndex& /*index*/) { - size += transfer_manager->GetByteSizeRequirement(subshape); - }); - } - return size; -} -} // namespace - StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index b92017c6cbc43d78ab4e5b32f25f5980b8d4ae56..6aca6ba38572c5311797fbb91acbbcd6610a3410 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -23,6 +23,23 @@ limitations under the License. namespace xla { +namespace { + +// Gather fusion instructions from 'instruction' into 'fusion_instructions'. +void GatherFusionInstructions( + HloInstruction* instruction, + std::vector* fusion_instructions) { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + for (auto* fused : instruction->fused_instructions()) { + if (fused->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(fused, fusion_instructions); + } + } + fusion_instructions->push_back(instruction); +} + +} // namespace + /* static */ StatusOr> LogicalBufferAnalysis::Run(const HloModule* module) { std::unique_ptr analysis( @@ -41,15 +58,19 @@ Status LogicalBufferAnalysis::Analyze() { // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) // fusion computations, and we don't want to try to assign buffers to those. + std::vector fusion_instructions; for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kFusion) { continue; } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + GatherFusionInstructions(instruction, &fusion_instructions); } } + for (auto* instruction : fusion_instructions) { + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + } return Status::OK(); } @@ -104,6 +125,21 @@ Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { return Status::OK(); } +Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction*) { + // RecvDone doesn't create a new buffer but rather aliases its input (Recv) + // tuple element at {0} to its output. + return Status::OK(); +} + +Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) { + // Send creates new buffers for the top-level tuple and the context (tuple + // element at {1}). Tuple element at {0} is an alias of the Send operand, so + // we don't need to create a new Logical Buffer for that. + NewLogicalBuffer(send, /*index=*/{}); + NewLogicalBuffer(send, /*index=*/{1}); + return Status::OK(); +} + Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) { // A Tuple instruction only creates the top-level buffer. NewLogicalBuffer(tuple, /*index=*/{}); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index a82e83ec5c3d2b0e011d85f3d03bea8fca870154..598d08b7203b25b194dfc3b3125ec58c96b2cd4c 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -60,6 +60,8 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleSend(HloInstruction* send) override; Status HandleSelect(HloInstruction* select) override; // A map from the buffer ID to the logical buffer diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index a0d08c288dbcc45e83a36ce7b094b04a9dbae532..7d8c05fffa4ab11d7dbf9956d2cb7ebd5bcdd3c4 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,12 +17,44 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + +bool IsAllowed(char character) { + auto c = static_cast(character); + return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; +} + +} // namespace + +NameUniquer::NameUniquer(const string& separator) { + CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + << "separator should comprises allowed characters only"; + separator_ = separator; +} + +/*static*/ string NameUniquer::GetSanitizedName(const string& name) { + string result = name; + CHECK(!result.empty()) << "name should not be empty"; + char c = static_cast(result[0]); + if (!isalpha(c) && c != '_') { + result[0] = '_'; + } + for (int i = 1; i < result.length(); i++) { + if (!IsAllowed(result[i])) { + result[i] = '_'; + } + } + return result; +} + string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); + root = GetSanitizedName(root); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index ed379b52258463b960dea788721c2c4325ef0260..4139c2700b25e8600182a034a8ac6f4f041c12e6 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -28,14 +28,21 @@ namespace xla { // Simple stateful class that helps generate "unique" names. To use it, simply // call GetUniqueName as many times as needed. The names returned by // GetUniqueName are guaranteed to be distinct for this instance of the class. +// Note that the names will be sanitized to match regexp +// "[a-zA-Z_][a-zA-Z0-9_.-]*". class NameUniquer { public: - explicit NameUniquer(const string& separator = "__") - : separator_(separator) {} + // The separator must contain allowed characters only: "[a-zA-Z0-9_.-]". + explicit NameUniquer(const string& separator = "__"); - // Get a unique name in a string, with an optional prefix for convenience. + // Get a sanitized unique name in a string, with an optional prefix for + // convenience. string GetUniqueName(tensorflow::StringPiece prefix = ""); + // Sanitizes and returns the name. Unallowed characters will be replaced with + // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + static string GetSanitizedName(const string& name); + private: // The string to use to separate the prefix of the name from the uniquing // integer value. diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 9f0747a6e2175a968d8f3661ac51512009e86f29..4258cf16876ab46dce6df062ab701b1b1a4a7580 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -60,12 +60,30 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); +} + +TEST_F(NameUniquerTest, Sanitize) { + NameUniquer uniquer("_"); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); + EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); + EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + + // Invalid characters will be replaced with '_'. + EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); + EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. - EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); - EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); - EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); - EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("_10", uniquer.GetUniqueName( + ".10")); // the leading '.' is replaced with '_'. + EXPECT_EQ("_10_1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("_10_2", uniquer.GetUniqueName("_10")); + EXPECT_EQ("foobar_", uniquer.GetUniqueName("foobar_")); + EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } } // namespace diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 3a1818de82d3fd305e2c6b3bd1f2cf8125806a75..aa974ee61a27de9c19e97d8a6eb48f9261ce4bd9 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -33,10 +33,32 @@ namespace se = ::perftools::gputools; namespace xla { +using tensorflow::str_util::Lowercase; + // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; +// The name of the interpreter platform. +constexpr char kInterpreter[] = "interpreter"; + +namespace { + +string CanonicalPlatformName(const string& name) { + string platform_str = Lowercase(name); + // "cpu" and "host" mean the same thing. + if (platform_str == "cpu") { + platform_str = "host"; + } + // "gpu" and "cuda" mean the same thing. + if (platform_str == "gpu") { + platform_str = "cuda"; + } + return platform_str; +} + +} // namespace + /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { se::MultiPlatformManager::PlatformMap platform_map; @@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() { return platforms; } -/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { +/* static */ StatusOr PlatformUtil::GetSolePlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); if (platforms.empty()) { return NotFound("no platforms found"); @@ -87,13 +109,77 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. - auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); }; - string platforms_string = tensorflow::str_util::Join(platforms, ", ", l); + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", platforms_string.c_str()); } +/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { + TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + if (platforms.empty()) { + return NotFound("no platforms found"); + } else if (platforms.size() == 1) { + return platforms[0]; + } else if (platforms.size() == 2) { + for (int i = 0; i < 2; i++) { + if (Lowercase(platforms[i]->Name()) == kInterpreter && + Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + return platforms[1 - i]; + } + } + } + + // Multiple platforms present and we can't pick a reasonable default. + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "must specify platform because more than one platform (except for the " + "interpreter platform) found: %s", + platforms_string.c_str()); +} + +/*static*/ StatusOr PlatformUtil::GetPlatform( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + for (se::Platform* platform : platforms) { + if (Lowercase(platform->Name()) == platform_str) { + return platform; + } + } + return InvalidArgument("platform %s not found", platform_name.c_str()); +} + +/*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); + + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + std::vector matched; + for (se::Platform* platform : platforms) { + if (Lowercase(platform->Name()) != platform_name) { + matched.push_back(platform); + } + } + if (matched.empty()) { + return InvalidArgument("unable to find platform that is not %s", + platform_name.c_str()); + } + if (matched.size() == 1) { + return matched[0]; + } + string matched_string = tensorflow::str_util::Join( + matched, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "found multiple platforms %s, but expected one platform except for %s", + matched_string.c_str(), platform_name.c_str()); +} + // Returns whether the device underlying the given StreamExecutor is supported // by XLA. static bool IsDeviceSupported(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index eac573703085aca2801885cd9abbe0022f1c029e..69188820a70707d9c9be10b20fb7de92ad4d9873 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -16,11 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_ +#include #include #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" namespace xla { @@ -34,10 +37,27 @@ class PlatformUtil { static StatusOr> GetSupportedPlatforms(); - // Convenience function which returns the default supported platform. If + // Convenience function which returns the default supported platform for + // tests. If exactly one supported platform is present, then this platform is + // the default platform. If exactly two platforms are present and one of them + // is the interpreter platform, then the other platform is the default + // platform. Otherwise returns an error. + static StatusOr GetDefaultPlatform(); + + // Convenience function which returns the sole supported platform. If // exactly one supported platform is present, then this platform is the // default platform. Otherwise returns an error. - static StatusOr GetDefaultPlatform(); + static StatusOr GetSolePlatform(); + + // Returns the platform according to the given name. Returns error if there is + // no such platform. + static StatusOr GetPlatform( + const string& platform_name); + + // Returns exactly one platform that does not have given name. Returns error + // if there is no such platform, or there are multiple such platforms. + static StatusOr GetPlatformExceptFor( + const string& platform_name); // Returns a vector of StreamExecutors for the given platform. The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index bac33d8102e07766531a4ce6eac77aff4971bfef..fe6993db983ef66f5de5a8eee1ed277318a7f7ee 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -430,9 +430,12 @@ StatusOr> Service::BuildExecutable( /*include_unreachable_instructions=*/ true)); + TF_ASSIGN_OR_RETURN( + module, backend->compiler()->RunHloPasses(std::move(module), executor)); + TF_ASSIGN_OR_RETURN( std::unique_ptr executable, - backend->compiler()->Compile(std::move(module), executor)); + backend->compiler()->RunBackend(std::move(module), executor)); if (!other_directory_path.empty()) { executable->set_session_module(std::move(session_module)); @@ -490,14 +493,20 @@ Service::ExecuteParallelAndRegisterResult( std::vector> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags) { + tensorflow::gtl::ArraySlice result_tags, + ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector::SmartPtr> streams; + std::vector> timers; // Global data handles for the computation results, one for each computation. std::vector result_handles; + // Device ID to stream executor, populated only with devices that are being + // profiled. + std::map index_to_profiled_streams; + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, backend->computation_placer()->AssignDevices( options_.number_of_replicas(), executables.size())); @@ -510,6 +519,21 @@ Service::ExecuteParallelAndRegisterResult( backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); + if (replica == 0 && profile != nullptr) { + timers.emplace_back( + new perftools::gputools::Timer(streams.back()->parent())); + streams.back() + ->InitTimer(timers.back().get()) + .ThenStartTimer(timers.back().get()); + CHECK(timers.front() != nullptr); + } + + if (replica == 0 && + executables[i]->module_config().debug_options().xla_hlo_profile() && + executables[i]->hlo_profiling_enabled()) { + index_to_profiled_streams[i] = streams.back().get(); + } + // Set up run options. ExecutableRunOptions options; options.set_stream(streams.back().get()); @@ -526,6 +550,10 @@ Service::ExecuteParallelAndRegisterResult( perftools::gputools::DeviceMemoryBase result, executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + if (replica == 0 && profile != nullptr) { + streams.back()->ThenStopTimer(timers.back().get()); + } + // All replicas share the same device address for the result allocation, // so only one of the replicas need to register the result handle. if (replica == 0) { @@ -543,6 +571,55 @@ Service::ExecuteParallelAndRegisterResult( } } + // For every stream that had profiling enabled, obtain and debug-dump the HLO + // profile. + for (auto& index_to_profiled_stream : index_to_profiled_streams) { + int64 device = index_to_profiled_stream.first; + se::Stream* stream = index_to_profiled_stream.second; + Executable* executable = executables[device]; + const HloModule& module = executable->module(); + HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(), + &executable->hlo_profile_index_map()); + TF_RETURN_IF_ERROR( + executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); + XLA_LOG_LINES( + tensorflow::INFO, + hlo_profile.ToString(streams[0]->parent()->GetDeviceDescription())); + hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute", + &hlo_profile); + } + + if (profile != nullptr) { + CHECK(!timers.empty()); + std::vector timer_nanoseconds; + timer_nanoseconds.reserve(timers.size()); + for (auto& timer : timers) { + timer_nanoseconds.push_back(timer->Nanoseconds()); + } + uint64 nanoseconds = + *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end()); + + // Merge in run-time profile information from execution_profile on the + // zeroth device. + profile->MergeFrom(executables[0]->execution_profile()); + + // Overall execution time (in nanoseconds) from the executor timer. + profile->set_compute_and_transfer_time_ns(nanoseconds); + + // TODO(b/28123297): On GPU we end up including transfer time in + // the compute time this way. Instead, we should get the correct + // value by measuring it. Setting the field here at least lets + // benchmarks provide *some* value for GPU computations. + // + // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually + // the compute time without the transfer time, so this way we get the + // correct compute time. We should instead have the correct value for + // compute_and_transfer_time and set compute_time to the compute time. + if (profile->compute_time_ns() == 0) { + profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); + } + } + return result_handles; } @@ -589,6 +666,7 @@ StatusOr Service::ExecuteAndRegisterResult( result, executable->ExecuteOnStreamWrapper( &run_options[0], profile, arguments)); } else { + // TODO(b/69985541): Support profiling also on this path. std::vector< tensorflow::gtl::ArraySlice> repeated_arguments(options_.number_of_replicas(), arguments); @@ -715,14 +793,16 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Execute the generated executables in parallel and return the device // handles for each computation's output. + ExecutionProfile profile; TF_ASSIGN_OR_RETURN( std::vector outputs, ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, execute_backend_.get(), device_handles, - computation_names)); + computation_names, &profile)); for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; *response.mutable_output() = output; + *response.mutable_profile() = profile; *result->add_responses() = response; } @@ -963,18 +1043,29 @@ tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, return tensorflow::Status::OK(); } +namespace { + +// Creates a clone of the given shaped buffer with the given device ordinal. The +// shape and DeviceMemoryBase values of the clone are identical to the original. +std::unique_ptr CloneShapedBufferOnDevice( + const ShapedBuffer& shaped_buffer, int device_ordinal) { + auto clone = MakeUnique( + shaped_buffer.shape(), shaped_buffer.platform(), device_ordinal); + ShapeUtil::ForEachSubshape( + shaped_buffer.shape(), [&clone, &shaped_buffer](const Shape& /*subshape*/, + const ShapeIndex& index) { + clone->AddBufferAtIndex(shaped_buffer.buffer(index), index); + }); + return clone; +} + +} // namespace + tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { Literal literal = Literal(arg->literal()); const Shape& shape = literal.shape(); - if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) { - // TODO(b/32990684): Tuple transfers to host end up allocating further - // buffers - implement that correctly. - return Unimplemented( - "Tuple transfers to the device not supported with replication."); - } - std::vector replicas; if (arg->has_device_handle()) { TF_ASSIGN_OR_RETURN(replicas, @@ -984,24 +1075,45 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } - // Allocate memory on the device, using the stream executor. The size of the - // allocation is obtained by examining the shape of the literal passed from - // the client. An allocation handle is returned in the response. - int64 allocation_size = - execute_backend_->transfer_manager()->GetByteSizeRequirement(shape); - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, - execute_backend_->memory_allocator()->Allocate( - replicas[0]->device_ordinal(), allocation_size)); - + // All memory allocation is done on the first replica. The allocations in all + // other replicas mirror the firsts'. + int master_device_ordinal = replicas[0]->device_ordinal(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr shaped_buffer, + ShapedBuffer::Allocate( + execute_backend_->transfer_manager()->HostShapeToDeviceShape(shape), + execute_backend_->memory_allocator(), master_device_ordinal, + [this](const Shape& shape) { + return execute_backend_->transfer_manager()->GetByteSizeRequirement( + shape); + })); + + // The allocation tracker only keeps track of the top-level buffer of the + // shape so pass in the buffer at shape index {}. + // TODO(b/37515654): Allocation tracker should hold a ShapedBuffer. *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape, - StrCat("TransferToServer literal of size ", allocation_size)); + execute_backend_.get(), master_device_ordinal, + shaped_buffer->buffer(/*index=*/{}), shape, + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape))); + // Transfer the data to the replicas. for (se::StreamExecutor* executor : replicas) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, literal, &allocation)); + if (executor->device_ordinal() == master_device_ordinal) { + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, literal, *shaped_buffer)); + } else { + // The replica is not the master. Create an cloned shaped buffer with + // the replica's device ordinal. This is required because + // TransferLiteralToDevice verifies that the device ordinal of the shaped + // buffer matches that of the executor. + std::unique_ptr clone = + CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal()); + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, literal, *clone)); + } } return tensorflow::Status::OK(); } @@ -1082,8 +1194,9 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, return InvalidArgument("computations may not be empty"); } - TF_ASSIGN_OR_RETURN(bool is_constant, - user_computation->IsConstant(arg->operand())); + TF_ASSIGN_OR_RETURN( + bool is_constant, + user_computation->IsConstant(arg->operand(), arg->num_parameters())); result->set_is_constant(is_constant); return tensorflow::Status::OK(); @@ -1101,8 +1214,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, return InvalidArgument("computations may not be empty"); } - TF_ASSIGN_OR_RETURN(bool is_constant, - user_computation->IsConstant(arg->operand())); + TF_ASSIGN_OR_RETURN( + bool is_constant, + user_computation->IsConstant(arg->operand(), arg->parameters_size())); if (!is_constant) { return InvalidArgument("Operand to ComputeConstant depends on parameter."); } @@ -1141,8 +1255,18 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, /*include_unreachable_instructions=*/ false)); + std::vector parameters(arg->parameters_size()); + for (int64 i = 0; i < arg->parameters_size(); ++i) { + parameters[i] = Literal(arg->parameters(i)); + } + std::vector parameter_ptrs; + std::transform(parameters.begin(), parameters.end(), + std::back_inserter(parameter_ptrs), + [](const Literal& literal) { return &literal; }); + HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); + TF_ASSIGN_OR_RETURN(auto result_literal, + evaluator.Evaluate(*module, parameter_ptrs)); // Since the shape_with_output_layout option in ExecutionOption is // non-effective to the Evaluator results, explicit relayout here. if (arg->has_output_layout()) { @@ -1266,6 +1390,17 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddConcatenateInstruction(arg->concatenate_request()); break; + case OpRequest::kConditionalRequest: { + TF_ASSIGN_OR_RETURN(UserComputation * true_computation, + computation_tracker_.Resolve( + arg->conditional_request().true_computation())); + TF_ASSIGN_OR_RETURN(UserComputation * false_computation, + computation_tracker_.Resolve( + arg->conditional_request().false_computation())); + handle_status = computation->AddConditionalInstruction( + arg->conditional_request(), *true_computation, *false_computation); + break; + } case OpRequest::kConstantRequest: handle_status = computation->AddConstantInstruction(arg->constant_request()); @@ -1274,6 +1409,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddConvertInstruction(arg->convert_request()); break; + case OpRequest::kBitcastConvertRequest: + handle_status = computation->AddBitcastConvertInstruction( + arg->bitcast_convert_request()); + break; case OpRequest::kConvolveRequest: handle_status = computation->AddConvolveInstruction(arg->convolve_request()); @@ -1286,6 +1425,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddCustomCallInstruction(arg->custom_call_request()); break; + case OpRequest::kDotRequest: + handle_status = computation->AddDotInstruction(arg->dot_request()); + break; case OpRequest::kDynamicSliceRequest: handle_status = computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); @@ -1406,8 +1548,12 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddRecvInstruction(arg->recv_request()); break; } + case OpRequest::kFftRequest: + return Unimplemented("FftRequest not implemented in XLA service."); + case OpRequest::OP_NOT_SET: + return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); default: - return InvalidArgument("Unsupported operation"); + return InvalidArgument("Unsupported operation in XLA service"); } TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 2452259f736054b5bf1f03fc5103d65eded7f398..47f4f0ade594089aa71717ef1e122886b0a6c7ac 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -272,8 +272,6 @@ class Service : public ServiceInterface { // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. - // has_hybrid_result is used to initialize the same-named field in - // HloModuleConfig -- see that class for documentation. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, @@ -327,7 +325,8 @@ class Service : public ServiceInterface { arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags); + tensorflow::gtl::ArraySlice result_tags, + ExecutionProfile* profile); // Convenience function for adding a function to a user computation. template diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 791d17365b1d756714b5feb0439e6919d9f23edc..9c1b951d017569a6dc89bc6583c72b5e42f0c07c 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -29,8 +29,10 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -89,8 +91,6 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_ATAN2; case HloOpcode::kComplex: return BINOP_COMPLEX; - case HloOpcode::kDot: - return BINOP_DOT; case HloOpcode::kMultiply: return BINOP_MUL; case HloOpcode::kAdd: @@ -440,6 +440,37 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, /* static */ StatusOr ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { + auto old_element_type = operand_shape.element_type(); + if (primitive_util::IsComplexType(old_element_type) && + !primitive_util::IsComplexType(new_element_type)) { + return Unimplemented( + "Unsupported conversion from complex to real type: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } + if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { + // Note: we may want to support tuple conversions via this operation in the + // future, by recursing into the tuple elements to check all sub-conversions + // are valid. For now we just reject them, though. + return InvalidArgument( + "cannot convert from or to tuple type; requested conversion: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } + + return ShapeUtil::ChangeElementType(operand_shape, new_element_type); +} + +/* static */ StatusOr ShapeInference::InferBitcastConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type) { + auto old_element_type = operand_shape.element_type(); + if (primitive_util::IsComplexType(old_element_type) != + primitive_util::IsComplexType(new_element_type)) { + return Unimplemented( + "Unsupported conversion between real and complex types: %s => %s", + ShapeUtil::HumanString(operand_shape).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) { // Note: we may want to support tuple conversions via this operation in the // future, by recursing into the tuple elements to check all sub-conversions @@ -449,6 +480,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), PrimitiveType_Name(new_element_type).c_str()); } + if (primitive_util::BitWidth(old_element_type) != + primitive_util::BitWidth(new_element_type)) { + return InvalidArgument( + "cannot bitcast types with different bit-widths: %s => %s", + PrimitiveType_Name(old_element_type).c_str(), + PrimitiveType_Name(new_element_type).c_str()); + } return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } @@ -510,8 +548,113 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); } -/* static */ StatusOr ShapeInference::InferDotOpShape(const Shape& lhs, - const Shape& rhs) { +// Current DotDimensionNumbers Requirements: +// +// Contracting Dimensions: +// *) Exactly one contracting dimension on both lhs and rhs. +// *) Contracting dimension size must be the same on both lhs and rhs. +// *) Contracting dimension numbers do not need to be the same (i.e. transposes +// are passed on to emitter implementations). +// +// Batch Dimensions: +// *) Same number of batch dimensions on both lhs and rhs. +// *) Same batch dimension numbers (and sizes) on both lhs and rhs. +// *) Batch dimension numbers must be ordered before contracting and +// non-contracting/non-batch dimension numbers. +// +// Non-Contracting-Non-Batch Dimensions: +// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// + +namespace { + +Status ValidateDotDimensionNumbers( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { + // Check that dimension numbers are in range. + auto dims_in_range = + [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + in_range) && + std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + }; + + tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice lhs_batch_dimensions = + AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); + tensorflow::gtl::ArraySlice rhs_batch_dimensions = + AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); + + if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + lhs_batch_dimensions) || + !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is out of range in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that dimension numbers are unique. + auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + tensorflow::gtl::FlatSet dim_set; + auto is_unique = [&dim_set](int64 i) -> bool { + return dim_set.insert(i).second; + }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + is_unique) && + std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + }; + + if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || + !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is not unique in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. + const int64 lhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(lhs) - + dimension_numbers.lhs_contracting_dimensions_size() - + dimension_numbers.lhs_batch_dimensions_size(); + const int64 rhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(rhs) - + dimension_numbers.rhs_contracting_dimensions_size() - + dimension_numbers.rhs_batch_dimensions_size(); + if (lhs_non_contracting_non_batch_dims < 0 || + lhs_non_contracting_non_batch_dims > 1 || + rhs_non_contracting_non_batch_dims < 0 || + rhs_non_contracting_non_batch_dims > 1) { + return InvalidArgument( + "batch and contracting dimension number mismatch " + "with rank "); + } + + // Check that batch dimension numbers are ordered before all others, and + // that they are monotonically increasing. + std::vector batch_dim_numbers(lhs_batch_dimensions.size()); + std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); + if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + lhs_batch_dimensions.begin()) || + !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + rhs_batch_dimensions.begin())) { + return InvalidArgument( + "batch dimension numbers must precede non-batch dimensions and be" + "monotonically increasing."); + } + + return Status::OK(); +} + +} // namespace + +/* static */ StatusOr ShapeInference::InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); @@ -531,37 +674,62 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return fail("element types do not match"); } - if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || - ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) { - return fail("dot only supports rank 1 or 2"); + if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + return fail("dot only supports rank 1 or above."); + } + + // Validate basic properties of dot dimension numbers. + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + + // Check that there is only one contracting dimension for both lhs and rhs. + if (dimension_numbers.lhs_contracting_dimensions_size() != + dimension_numbers.rhs_contracting_dimensions_size() || + dimension_numbers.lhs_contracting_dimensions_size() != 1) { + return fail("must specify one contracting dimension for both lhs and rhs."); + } + + // Check that contracting dimension sizes match. + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(0); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension)) { + return fail("contracting dimension sizes do not match."); } - // Determine the index of the contracted dimensions for input tensors. - // dimensions -1 of lhs and dimension 0 of rhs are contracted. - int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1); - int64 rhs_contracted_dimension = 0; + // Check that number of batch dimensions match. + if (dimension_numbers.lhs_batch_dimensions_size() != + dimension_numbers.rhs_batch_dimensions_size()) { + return fail("must the same number of batch dimensions for lhs and rhs."); + } - // Check if the contracted dimension sizes are the same. - if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) && - rhs_contracted_dimension < ShapeUtil::Rank(rhs)) && - lhs.dimensions(lhs_contracted_dimension) != - rhs.dimensions(rhs_contracted_dimension)) { - return fail("contracted dimensions mismatch"); + // Check that batch dimension numbers and sizes match. + for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { + if (dimension_numbers.lhs_batch_dimensions(i) != + dimension_numbers.rhs_batch_dimensions(i) || + lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { + return fail("batch dimension numbers and sizes must match for lhs/rhs."); + } } // The ranks of lhs and rhs are decremented by 1 respectively due to the // contraction, and added for the rank of the result. When an input tensor is // a scalar, its contribution to the rank of the result is 0. // Generate the result dimensions in order, rhs dimensions followed by lhs - // dimensions except the contracted dimensions. + // dimensions except the contracted and batch dimensions. std::vector dimensions; + std::unordered_set rhs_batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracted_dimension) { + if (i != lhs_contracting_dimension) { dimensions.push_back(lhs.dimensions(i)); } } for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracted_dimension) { + if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { dimensions.push_back(rhs.dimensions(i)); } } @@ -770,11 +938,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation")); - TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + lhs, tensorflow::strings::StrCat("lhs of binary operation ", + BinaryOperation_Name(operation)))); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + rhs, tensorflow::strings::StrCat("rhs of binary operation ", + BinaryOperation_Name(operation)))); switch (operation) { - case BINOP_DOT: - return InferDotOpShape(lhs, rhs); case BINOP_MAX: case BINOP_MIN: case BINOP_SUB: @@ -1402,7 +1572,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str()); } - if (dnums.spatial_dimensions_size() != + if (dnums.input_spatial_dimensions_size() != dnums.kernel_spatial_dimensions_size()) { return InvalidArgument( "Both arguments to convolution must have same number of dimensions.\n" @@ -1410,7 +1580,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( window.DebugString().c_str()); } - const int num_spatial_dims = dnums.spatial_dimensions_size(); + const int num_spatial_dims = dnums.input_spatial_dimensions_size(); if (window.dimensions_size() != num_spatial_dims) { return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" @@ -1439,8 +1609,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( std::vector input_dnums(num_dims); input_dnums[0] = dnums.input_batch_dimension(); input_dnums[1] = dnums.input_feature_dimension(); - std::copy(dnums.spatial_dimensions().begin(), - dnums.spatial_dimensions().end(), input_dnums.begin() + 2); + std::copy(dnums.input_spatial_dimensions().begin(), + dnums.input_spatial_dimensions().end(), input_dnums.begin() + 2); std::sort(input_dnums.begin(), input_dnums.end()); std::vector window_dnums(num_dims); @@ -1450,12 +1620,20 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2); std::sort(window_dnums.begin(), window_dnums.end()); + std::vector output_dnums(num_dims); + output_dnums[0] = dnums.output_batch_dimension(); + output_dnums[1] = dnums.output_feature_dimension(); + std::copy(dnums.output_spatial_dimensions().begin(), + dnums.output_spatial_dimensions().end(), output_dnums.begin() + 2); + std::sort(output_dnums.begin(), output_dnums.end()); + std::vector expected_dnums(num_dims); std::iota(expected_dnums.begin(), expected_dnums.end(), 0); const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; }; if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) || - !std::all_of(window_dnums.begin(), window_dnums.end(), in_range)) { + !std::all_of(window_dnums.begin(), window_dnums.end(), in_range) || + !std::all_of(output_dnums.begin(), output_dnums.end(), in_range)) { return InvalidArgument( "A dimension number is out of range in convolution: %s", dnums.DebugString().c_str()); @@ -1473,10 +1651,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "once: %s", dnums.DebugString().c_str()); } + if (output_dnums != expected_dnums) { + return InvalidArgument( + "Output dimensions of convolution must contain each dimension exactly " + "once: %s", + dnums.DebugString().c_str()); + } std::vector input_spatial_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { - input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i)); + input_spatial_dims[i] = lhs.dimensions(dnums.input_spatial_dimensions(i)); } const int64 input_features = lhs.dimensions(dnums.input_feature_dimension()); const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension()); @@ -1524,17 +1708,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dimensions[dnums.output_batch_dimension()] = input_batch; dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { - dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i); + dimensions[dnums.output_spatial_dimensions(i)] = + window_output_shape.dimensions(i); } return ShapeUtil::MakeShape(lhs.element_type(), dimensions); } /* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( - const Shape& operand) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand, "operand of cross replica sum")); - return operand; + tensorflow::gtl::ArraySlice operand_shapes) { + for (const Shape* operand_shape : operand_shapes) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum")); + } + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); + } + return ShapeUtil::MakeTupleShape(operand_shape_values); } /* static */ StatusOr ShapeInference::InferReduceShape( @@ -1900,6 +2094,64 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return init; } +/* static */ StatusOr ShapeInference::InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation) { + if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { + return InvalidArgument("predicate must be a boolean; got %s.", + ShapeUtil::HumanString(predicate).c_str()); + } + + if (true_computation.parameters_size() != 1) { + return InvalidArgument("true_computation must take 1 argument; got %d.", + true_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { + auto true_shape_string = [&]() { + return tensorflow::strings::Printf( + "true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand).c_str(), + ShapeUtil::HumanString(true_computation).c_str()); + }; + return InvalidArgument( + "true_operand must match the shape of the only parameter of " + "true_computation: got %s.", + true_shape_string().c_str()); + } + + if (false_computation.parameters_size() != 1) { + return InvalidArgument("false_computation must take 1 argument; got %d.", + false_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { + auto false_shape_string = [&]() { + return tensorflow::strings::Printf( + "false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand).c_str(), + ShapeUtil::HumanString(false_computation).c_str()); + }; + return InvalidArgument( + "false_operand must match the shape of the only parameter of " + "false_computation: got %s.", + false_shape_string().c_str()); + } + if (!ShapeUtil::Compatible(true_computation.result(), + false_computation.result())) { + auto shape_string = [&]() { + return tensorflow::strings::Printf( + "true_computation result: %s; false_computation result: %s.", + ShapeUtil::HumanString(true_computation.result()).c_str(), + ShapeUtil::HumanString(false_computation.result()).c_str()); + }; + return InvalidArgument( + "the result of true_computation and false_computation must have the " + "same shape: got %s.", + shape_string().c_str()); + } + return true_computation.result(); +} + /* static */ StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); @@ -1943,7 +2195,10 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( !std::is_permutation(dimensions.begin(), dimensions.end(), indices.begin())) { return InvalidArgument( - "Reshape dimensions not a permutation of the operand dimensions."); + "Reshape dimensions [%s] are not a permutation of the operand " + "dimensions (operand shape is %s).", + tensorflow::str_util::Join(dimensions, ",").c_str(), + ShapeUtil::HumanString(operand).c_str()); } return inferred_shape; diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d5d497176d6c340d8c8f34cdacf6a9e32040c387..c06340d2d5df239642eb0af4836df64a898a1eaf 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,8 +109,10 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); - // Infers the shape produced a cross replica sum with the given operand shape. - static StatusOr InferCrossReplicaSumShape(const Shape& operand); + // Infers the shape produced a cross replica sum with the given operand + // shapes. + static StatusOr InferCrossReplicaSumShape( + tensorflow::gtl::ArraySlice operand_shapes); // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. @@ -178,6 +180,12 @@ class ShapeInference { const ProgramShape& body, const Shape& init); + // Infers the shape produced by a conditional operation. + static StatusOr InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation); + // Infers the shape produced by a broadcast operation. static StatusOr InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); @@ -204,6 +212,13 @@ class ShapeInference { static StatusOr InferConvertShape(const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the given operand shape can be bitcast converted to + // the target output_shape via a bitcast convert instruction -- the + // requirement is that the shape is identical except for the element type and + // the element types have identical bit-widths. + static StatusOr InferBitcastConvertShape( + const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the input data type for a reduce-precision operation, // and returns the result shape. static StatusOr InferReducePrecisionShape(const Shape& operand_shape, @@ -222,11 +237,13 @@ class ShapeInference { tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply); - private: // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. - static StatusOr InferDotOpShape(const Shape& lhs, const Shape& rhs); + static StatusOr InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. // Note: By "element-wise" we mean operations that look at a single element in diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index d12f7bd1453890db3280e54719a6ce811006336d..99d87f3b550ae72befe254f23fad080dd210aaf4 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -395,8 +395,10 @@ TEST_F(ShapeInferenceTest, Convolve) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); @@ -437,8 +439,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3}); @@ -480,8 +484,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { dnums.set_output_batch_dimension(0); dnums.set_input_feature_dimension(1); dnums.set_output_feature_dimension(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); // Dimension order: x1, batch, feature, x0 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4}); @@ -524,8 +530,10 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { dnums.set_output_batch_dimension(3); dnums.set_input_feature_dimension(2); dnums.set_output_feature_dimension(2); - dnums.add_spatial_dimensions(0); - dnums.add_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(0); + dnums.add_output_spatial_dimensions(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0 dnums.set_kernel_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); @@ -890,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar vector: error TEST_F(ShapeInferenceTest, ScalarDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); + ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("dot only supports rank")); @@ -899,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { // 3D 2D: error TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("batch and contracting dimension number mismatch")); } // vector vector -> scalar TEST_F(ShapeInferenceTest, VectorDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix vector -> vector TEST_F(ShapeInferenceTest, MatrixDotVector) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // vector matrix -> vector TEST_F(ShapeInferenceTest, VectorDotMatrix) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix matrix -> matrix TEST_F(ShapeInferenceTest, MatrixDotMatrix) { - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status_match = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } +// BatchMatMul with two batch dimensions and one contracting dimension. +TEST_F(ShapeInferenceTest, DotGeneral) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status_match = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE( + ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) + << "inferred: " + << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) + << " expected: " << ShapeUtil::HumanString(output_shape); +} + +// BatchMatMul with two contracting dimensions fails. +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("must specify one contracting dimension for both " + "lhs and rhs")); +} + +// BatchMatMul with different batch dimension sizes fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with different batch dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers must precede non-batch")); +} + +// BatchMatMul with out-of-range dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is out of range")); +} + +// BatchMatMul with non-unique dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is not unique")); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. @@ -1288,5 +1437,80 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } +TEST_F(ShapeInferenceTest, Conditional) { + auto inferred_status0 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_IS_OK(inferred_status0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); + + auto inferred_status1 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, vector_32_, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)); + EXPECT_IS_OK(inferred_status1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); + + auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); + auto inferred_status2 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, tuple_f32_v32, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)); + EXPECT_IS_OK(inferred_status2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); + + auto inferred_status_error0 = ShapeInference::InferConditionalShape( + s32_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error0.ok()); + EXPECT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("predicate must be a boolean")); + + auto inferred_status_error1 = ShapeInference::InferConditionalShape( + pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)); + EXPECT_FALSE(inferred_status_error1.ok()); + EXPECT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("true_computation must take 1 argument")); + + auto inferred_status_error2 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error2.ok()); + EXPECT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("true_operand must match the shape of the only " + "parameter of true_computation")); + + auto inferred_status_error3 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)); + EXPECT_FALSE(inferred_status_error3.ok()); + EXPECT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("false_computation must take 1 argument")); + + auto inferred_status_error4 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + EXPECT_FALSE(inferred_status_error4.ok()); + EXPECT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("false_operand must match the shape of the only " + "parameter of false_computation")); + + auto inferred_status_error5 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)); + EXPECT_FALSE(inferred_status_error5.ok()); + EXPECT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("the result of true_computation and false_computation " + "must have the same shape")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index a2a442eb1a33d976a114f68d112a7d8f3b540f4b..aa0a24a2833ec0b152f32f26f32e57ec6f7b5d14 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -21,17 +21,19 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" namespace se = ::perftools::gputools; namespace xla { +using ::tensorflow::strings::Appendf; + /* static */ StatusOr> ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, const se::Platform* platform, @@ -49,6 +51,34 @@ ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, return std::move(shaped_buffer); } +/* static */ StatusOr> ShapedBuffer::Allocate( + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn) { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument("Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(shape).c_str()); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + auto shaped_buffer = WrapUnique( + new ShapedBuffer(shape, allocator->platform(), device_ordinal)); + + // Allocate an appropriate sized buffer for each element in the shape + // including the tuple pointer arrays. + for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { + const ShapeIndex& index = pair.first; + size_t& buffer_entry = pair.second; + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase memory_base, + allocator->Allocate(shaped_buffer->device_ordinal(), + shape_size_fn(ShapeUtil::GetSubshape( + shaped_buffer->shape(), index)))); + shaped_buffer->buffers_.push_back(memory_base); + buffer_entry = shaped_buffer->buffers_.size() - 1; + } + + return std::move(shaped_buffer); +} + ShapedBuffer::ShapedBuffer(const Shape& shape, const se::Platform* platform, int device_ordinal) : shape_(shape), @@ -63,6 +93,14 @@ void ShapedBuffer::clear() { } } +void ShapedBuffer::AddBufferAtIndex( + const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& shape_index) { + *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) = + buffers().size(); + mutable_buffers()->push_back(buffer); +} + const se::DeviceMemoryBase& ShapedBuffer::buffer( const ShapeIndex& index) const { return buffers_[shape_index_to_buffer_entry_.element(index)]; @@ -72,67 +110,37 @@ se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) { return &buffers_[shape_index_to_buffer_entry_.element(index)]; } -/* static */ StatusOr> -ScopedShapedBuffer::Allocate(const Shape& shape, - DeviceMemoryAllocator* allocator, - int device_ordinal) { - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - auto shaped_buffer = - WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal)); - - // Allocate an appropriate sized buffer for each element in the shape - // including the tuple pointer arrays. Gather tuple element addresses in - // 'element_addresses'. These will be written in the respective tuple's array - // of pointers on the device. - TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, - TransferManager::GetForPlatform(allocator->platform())); - ShapeTree> element_addresses(shape); - for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { - const ShapeIndex& index = pair.first; - size_t& buffer_entry = pair.second; - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase memory_base, - shaped_buffer->allocator_->Allocate( - shaped_buffer->device_ordinal(), - transfer_manager->GetByteSizeRequirement( - ShapeUtil::GetSubshape(shaped_buffer->shape(), index)))); - shaped_buffer->buffers_.push_back(memory_base); - buffer_entry = shaped_buffer->buffers_.size() - 1; - - // If this is a tuple element, then push the address on to the - // vector of tuple element addresses. - if (!index.empty()) { - ShapeIndex parent_index = index; - parent_index.pop_back(); - element_addresses.mutable_element(parent_index)->push_back(memory_base); - } - } - - // Fill in the tuple pointer arrays with the addresses of their respective - // elements. - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - allocator->platform()->ExecutorForDevice( - shaped_buffer->device_ordinal())); - for (const auto& pair : element_addresses) { - const ShapeIndex& index = pair.first; - const std::vector& addresses = pair.second; - const Shape& subshape = ShapeUtil::GetSubshape(shape, index); +string ShapedBuffer::ToString() const { + string s = "ShapedBuffer(" + platform_->Name() + "):\n"; + ShapeUtil::ForEachSubshape( + shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { + string shape_str; + if (ShapeUtil::IsTuple(subshape)) { + shape_str = "tuple"; + } else { + shape_str = ShapeUtil::HumanStringWithLayout(subshape); + } + const se::DeviceMemoryBase& memory = buffer(index); + Appendf(&s, " %s%p (%lld bytes) : %s\n", + string(index.size() * 2, ' ').c_str(), memory.opaque(), + memory.size(), shape_str.c_str()); + }); + return s; +} - if (addresses.empty()) { - TF_RET_CHECK(!ShapeUtil::IsTuple(subshape) || - ShapeUtil::TupleElementCount(subshape) == 0); - continue; - } - TF_RET_CHECK(ShapeUtil::IsTuple(subshape)); - TF_RETURN_IF_ERROR(transfer_manager->WriteTuplePointersToDevice( - executor, addresses, subshape, shaped_buffer->mutable_buffer(index))); - } +std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { + out << buffer.ToString(); + return out; +} - return std::move(shaped_buffer); +/* static */ StatusOr> +ScopedShapedBuffer::Allocate( + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr unscoped_buffer, + ShapedBuffer::Allocate(shape, allocator, device_ordinal, shape_size_fn)); + return MakeScoped(unscoped_buffer.get(), allocator); } /* static */ diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index e5ea06fb136fa714eab0f340f98b7191a4c5caa3..ca8bfff674d2fad0fc5731cb2dc30b60bcf11997 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_ #include +#include +#include #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/shape_tree.h" @@ -41,6 +43,14 @@ class ShapedBuffer { const Shape& shape, const perftools::gputools::Platform* platform, int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer); + // Return a newly allocated ShapedBuffer of an arbitrary shape. Array buffers + // (leaves in the shape) are allocated and uninitialized. Tuple buffers (if + // any) are allocated and initialized to the backend-specific representation + // of an array of pointers to the tuple elements. + static StatusOr> Allocate( + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn); + ShapedBuffer(const Shape& shape, const perftools::gputools::Platform* platform, int device_ordinal); @@ -75,6 +85,12 @@ class ShapedBuffer { // Set all device memory pointers in the object to null. void clear(); + // Adds a new buffer at the given shape index. + void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& shape_index); + + string ToString() const; + protected: // The shape of the device buffer with layout. const Shape shape_; @@ -95,17 +111,17 @@ class ShapedBuffer { ShapeTree shape_index_to_buffer_entry_; }; +std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); + // ShapedBuffer derived class which allocates all internal buffers on // construction and deallocates the memory when the object is // destructed. class ScopedShapedBuffer : public ShapedBuffer { public: - // Return a newly allocated ScopedShapedBuffer of an arbitrary shape. Array - // buffers (leaves in the shape) are allocated and uninitialized. Tuple - // buffers (if any) are allocated and initialized to the backend-specific - // representation of an array of pointers to the tuple elements. + // Identical to ShapedBuffer::Allocate. static StatusOr> Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal); + const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, + const std::function& shape_size_fn); // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the // deallocation of the device memory held in the shaped buffer. All device diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 4da0a0d36841a6dfaed5c7eebdfb9e6980ad1090..d5f53ad56fb019d0ae7c27fc28706f05614ece68 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -28,12 +28,9 @@ limitations under the License. namespace se = ::perftools::gputools; namespace xla { - -/* static */ tensorflow::mutex* -TransferManager::platform_transfer_manager_mutex() { - static tensorflow::mutex* m = new tensorflow::mutex; - return m; -} +/* static */ tensorflow::mutex + TransferManager::platform_transfer_manager_mutex_( + tensorflow::LINKER_INITIALIZED); /* static */ std::map* @@ -47,7 +44,7 @@ TransferManager::GetPlatformTransferManagers() { se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { tensorflow::mutex_lock lock( - *TransferManager::platform_transfer_manager_mutex()); + TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); CHECK(managers->find(platform_id) == managers->end()); (*managers)[platform_id].creation_function = creation_function; @@ -56,7 +53,7 @@ TransferManager::GetPlatformTransferManagers() { /* static */ StatusOr TransferManager::GetForPlatform( const se::Platform* platform) { tensorflow::mutex_lock lock( - *TransferManager::platform_transfer_manager_mutex()); + TransferManager::platform_transfer_manager_mutex_); auto* managers = GetPlatformTransferManagers(); auto it = managers->find(platform->id()); @@ -75,6 +72,39 @@ TransferManager::GetPlatformTransferManagers() { return it->second.manager.get(); } +Status TransferManager::WriteTupleIndexTables( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) { + VLOG(2) << "Writing tuple index tables to ShapedBuffer rooted at " + << device_buffer.buffer(/*index=*/{}).opaque() + << "; shape: " << ShapeUtil::HumanString(device_buffer.shape()); + + TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + + return ShapeUtil::ForEachSubshapeWithStatus( + device_buffer.shape(), + [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { + if (ShapeUtil::IsTuple(device_subshape)) { + se::DeviceMemoryBase device_memory = device_buffer.buffer(index); + TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == + device_memory.size()); + + std::vector elements; + ShapeIndex element_index = index; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(device_subshape); + ++i) { + element_index.push_back(i); + elements.push_back(device_buffer.buffer(element_index)); + element_index.pop_back(); + } + return WriteTuplePointersToDevice(executor, elements, device_subshape, + &device_memory); + } + + return Status::OK(); + }); +} + Status TransferManager::TransferBufferFromDevice( se::StreamExecutor* executor, const se::DeviceMemoryBase& source, int64 size, void* destination) { diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index f63d91604cf40edfae98b56a8bacdbded697ffc3..be9b769ac8cf3cf1fcfd13dfe9f1458e55a5323d 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -47,6 +48,8 @@ class TransferManager { // executor. device_shape is the shape, including layout, of the data on the // device, while literal_shape will be the shape for the literal. device_shape // and literal_shape must be compatible, but need not have the same layout. + // TODO(b/66694934): Remove TransferLiteral* methods which accept bare + // DeviceMemoryBase. virtual Status TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const perftools::gputools::DeviceMemoryBase& region, @@ -59,6 +62,28 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal, perftools::gputools::DeviceMemoryBase* region) = 0; + // Returns the shape of the on-device representation for the given shape on + // the host. This is intended for use with ShapedBuffer where buffers are + // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user + // needing to consider device-specific behaviors. + virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { + return host_shape; + } + + // Transfers the data held in the given ShapedBuffer into the provided literal + // using the provided executor. literal_shape will be the shape for the + // literal. The shape of the ShapedBuffer and DeviceShape(literal_shape) must + // be compatible, but need not have the same layout. + virtual StatusOr> TransferLiteralFromDevice( + perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer) = 0; + + // Transfers the given literal into the previously allocated device memory + // represented by the given ShapedBuffer using the given executor. + virtual Status TransferLiteralToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const ShapedBuffer& device_buffer) = 0; + // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed( @@ -97,15 +122,11 @@ class TransferManager { const perftools::gputools::DeviceMemoryBase& source, const Shape& shape) = 0; - // Writes the given device-memory pointers in 'elements' to the given region - // to construct a tuple in the platform-specific tuple representation. This - // can handle nested tuples as well. In the nested case, the element - // DeviceMemoryBase points to another array of pointers on the device. - virtual Status WriteTuplePointersToDevice( - perftools::gputools::StreamExecutor* executor, - tensorflow::gtl::ArraySlice - elements, - const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + // Given an allocated ShapedBuffer, constructs the tuple index table(s) in + // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the + // ShapedBuffer is array-shaped this method does nothing. + Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, + const ShapedBuffer& device_buffer); // Returns all buffer pointers that the tuple `source` refers to. Unlike // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested @@ -119,24 +140,7 @@ class TransferManager { // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. - virtual int64 GetByteSizeRequirement(const Shape& shape) = 0; - - // Transfer a memory block of the given size from the device source into the - // 'destination' buffer. - // - // size is the size to transfer to destination in bytes. - virtual Status TransferBufferFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, int64 size, - void* destination); - - // Transfer a memory block of the given size from 'source' buffer to the given - // destination of the device. - // - // size is the size to transfer from source in bytes. - virtual Status TransferBufferToDevice( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source, perftools::gputools::DeviceMemoryBase* destination); + virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; typedef std::unique_ptr (*TransferManagerCreationFunction)(); @@ -157,12 +161,37 @@ class TransferManager { static StatusOr GetForPlatform( const perftools::gputools::Platform* platform); + protected: + // Transfer a memory block of the given size from the device source into the + // 'destination' buffer. + // + // size is the size to transfer to destination in bytes. + virtual Status TransferBufferFromDevice( + perftools::gputools::StreamExecutor* executor, + const perftools::gputools::DeviceMemoryBase& source, int64 size, + void* destination); + + // Transfer a memory block of the given size from 'source' buffer to the given + // destination of the device. + // + // size is the size to transfer from source in bytes. + virtual Status TransferBufferToDevice( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source, perftools::gputools::DeviceMemoryBase* destination); + + // Writes the given device-memory pointers in 'elements' to the given region + // to construct a tuple in the platform-specific tuple representation. This + // can handle nested tuples as well. In the nested case, the element + // DeviceMemoryBase points to another array of pointers on the device. + virtual Status WriteTuplePointersToDevice( + perftools::gputools::StreamExecutor* executor, + tensorflow::gtl::ArraySlice + elements, + const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + private: - // Routine that returns the mutex that guards the - // platform-to-transfer manager map. Done as a routine to - // ensure correct initialization ordering, since RegisterTransferManager - // can be called during program initialization time. - static tensorflow::mutex* platform_transfer_manager_mutex(); + // The mutex that guards the platform-to-transfer manager map. + static tensorflow::mutex platform_transfer_manager_mutex_; // State kept for each kind of TransferManager. Registration functions // set up creation_function, and then we use that to lazily create diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc deleted file mode 100644 index c25a0861e9b90bc0f2cde43933e14204aa4e3598..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ /dev/null @@ -1,161 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include - -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" - -namespace se = ::perftools::gputools; - -namespace xla { - -namespace { - -class CpuTransferManagerTest : public ::testing::Test { - protected: - CpuTransferManagerTest() - : transfer_manager_(se::host::kHostPlatformId, - /*pointer_size=*/sizeof(void*)) { - se::Platform* platform = - se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) - .ValueOrDie(); - stream_exec_ = - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie(); - } - - ~CpuTransferManagerTest() override {} - - se::StreamExecutor* stream_exec_; - GenericTransferManager transfer_manager_; -}; - -TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) { - std::vector storage(sizeof(uint32), '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = Literal::CreateR0(42); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ(42, *reinterpret_cast(&storage[0])); -} - -TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) { - std::vector storage(4 * sizeof(float), '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = - Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ(1.25f, *reinterpret_cast(&storage[0])); - CHECK_EQ(2.5f, *reinterpret_cast(&storage[sizeof(float)])); - CHECK_EQ(-17.0f, *reinterpret_cast(&storage[2 * sizeof(float)])); - CHECK_EQ(-20.125f, *reinterpret_cast(&storage[3 * sizeof(float)])); -} - -TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) { - std::vector storage(16, '\x00'); - se::DeviceMemoryBase memptr(storage.data(), storage.size()); - const char* str = "0123456789abcdef"; - std::unique_ptr literal = Literal::CreateR1U8(str); - TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, - &memptr)); - - CHECK_EQ('0', storage[0]); - CHECK_EQ('8', storage[8]); - CHECK_EQ('f', storage[15]); -} - -TEST_F(CpuTransferManagerTest, TransferR0U32FromDevice) { - std::vector storage(1, 42); - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(U32, {}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - - LiteralTestUtil::ExpectR0Equal(42, literal); -} - -TEST_F(CpuTransferManagerTest, TransferR1F32FromDevice) { - std::vector storage{1.25f, 2.5f, -17.0f, -20.125f}; - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(F32, {4}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - - LiteralTestUtil::ExpectR1Equal({1.25, 2.5, -17.0, -20.125}, literal); -} - -TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) { - std::vector storage{'k', 'l', 'm', 'n'}; - se::DeviceMemoryBase memptr(storage.data(), - storage.size() * sizeof(storage[0])); - Literal literal; - const Shape shape = ShapeUtil::MakeShape(U8, {4}); - TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice( - stream_exec_, memptr, shape, shape, &literal)); - CHECK_EQ("klmn", literal.u8s_string()); -} - -TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) { - std::vector storage{1, 5, 42}; - int64 size = storage.size() * sizeof(storage[0]); - se::DeviceMemoryBase memptr(storage.data(), size); - - std::vector dest(3, 0); - TF_CHECK_OK(transfer_manager_.TransferBufferFromDevice(stream_exec_, memptr, - size, dest.data())); - ASSERT_EQ(1, dest[0]); - ASSERT_EQ(5, dest[1]); - ASSERT_EQ(42, dest[2]); -} - -TEST_F(CpuTransferManagerTest, TransferBufferToDevice) { - int64 size = 3 * sizeof(uint64); - std::vector storage(size, 0); - se::DeviceMemoryBase memptr(storage.data(), size); - - std::vector dest{1, 5, 42}; - TF_CHECK_OK(transfer_manager_.TransferBufferToDevice(stream_exec_, size, - dest.data(), &memptr)); - std::vector* storage64 = - reinterpret_cast*>(&storage); - ASSERT_EQ(1, (*storage64)[0]); - ASSERT_EQ(5, (*storage64)[1]); - ASSERT_EQ(42, (*storage64)[2]); -} - -// TODO(b/24679870): add similar tests for GPUs - -} // namespace - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index 8c2640adf52f10c387e7a9c09c0d73a09c054919..42b616f4c3446957eec13874eac74e80195f85a4 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -58,27 +58,11 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( return {}; } - const ConvolutionDimensionNumbers& dnums = - convolution.convolution_dimension_numbers(); - TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < convolution.operand_count(); ++i) { auto& operand = *convolution.operand(i); if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) { - const auto& transpose_dimensions = operand.dimensions(); - // We can transpose the LHS so long as it doesn't move around spatial - // dimensions because ConvolutionDimensionNumbers doesn't have different - // fields for input and output spatial dimensions. - if (i == 0 && - std::any_of(dnums.spatial_dimensions().begin(), - dnums.spatial_dimensions().end(), - [&](const int64 spatial_dimension) { - return transpose_dimensions[spatial_dimension] != - spatial_dimension; - })) { - continue; - } operand_set.push_back(i); } } @@ -118,6 +102,10 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; auto& operand_indices = pair.second; + if (operand_indices.empty()) { + return false; + } + const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); ConvolutionDimensionNumbers new_dnums = dnums; @@ -137,8 +125,9 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { transpose_dimensions[dnums.input_batch_dimension()]); new_dnums.set_input_feature_dimension( transpose_dimensions[dnums.input_feature_dimension()]); - for (const auto& spatial_dimension : dnums.spatial_dimensions()) { - CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + for (auto& input_spatial_dimension : + *new_dnums.mutable_input_spatial_dimensions()) { + input_spatial_dimension = transpose_dimensions[input_spatial_dimension]; } new_lhs = &transpose_operand; } else { diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 00462f9be1e9beb2f2694060ebfaa70b0b9dd4a0..caa1a111ad880b9dee62c1c94e32e8275c196fbf 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -64,9 +64,12 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -104,9 +107,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/transpose0, /*rhs=*/transpose1)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1, 3}), + /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -169,9 +175,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -362,10 +371,82 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { EXPECT_EQ( dnums.input_batch_dimension(), new_conv->convolution_dimension_numbers().input_feature_dimension()); - EXPECT_EQ(dnums.spatial_dimensions(0), - new_conv->convolution_dimension_numbers().spatial_dimensions(0)); - EXPECT_EQ(dnums.spatial_dimensions(1), - new_conv->convolution_dimension_numbers().spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); +} + +// Test that a transpose of every dimension in the activations gets folded into +// convolution. +TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. + std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } } // namespace diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index df537bd7c15a1f15ed77ca9be6ce70fbfd2e63be..0c848566478a25d4862cb0698e029dacd71f7a6a 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -120,6 +120,23 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, tree_.mutable_element(index)->tuple_sources.insert(tuple); } +namespace { + +// Gather fusion instructions from 'instruction' into 'fusion_instructions'. +void GatherFusionInstructions( + HloInstruction* instruction, + std::vector* fusion_instructions) { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + for (auto* fused : instruction->fused_instructions()) { + if (fused->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(fused, fusion_instructions); + } + } + fusion_instructions->push_back(instruction); +} + +} // namespace + /* static */ StatusOr> TuplePointsToAnalysis::Run(const HloModule* module) { auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module); @@ -137,20 +154,23 @@ Status TuplePointsToAnalysis::Analyze() { logical_buffer_aliases_.resize( logical_buffer_analysis_->num_logical_buffers()); + std::vector fusion_instructions; for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); - // Run points-to analysis on fusion instructions in 'computation'. for (auto* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kFusion) { - continue; + if (instruction->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(instruction, &fusion_instructions); } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); - TF_RETURN_IF_ERROR( - PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); } } + // Run points-to analysis on fusion instructions in 'computation'. + for (auto* instruction : fusion_instructions) { + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); + } XLA_VLOG_LINES(3, ToString()); @@ -253,6 +273,64 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { + // RecvDone aliases its input (Recv) tuple element {0} to its output. + PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); + const PointsToSet& operand_points_to_set = + GetPointsToSet(recv_done->operand(0)); + + // Recursively copy the points to set of the operand tuple {0}. + points_to_set.ForEachMutableElement( + [this, &points_to_set, &operand_points_to_set]( + const ShapeIndex& index, PointsToSet::BufferList* buffers) { + ShapeIndex src_index({0}); + for (auto element : index) { + src_index.push_back(element); + } + *buffers = operand_points_to_set.element(src_index); + for (auto& tuple_source : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(index, tuple_source); + } + }); + return Status::OK(); +} + +Status TuplePointsToAnalysis::HandleSend(HloInstruction* send) { + // Send creates a tuple of {aliased operand, U32 context}. + PointsToSet& points_to_set = CreateEmptyPointsToSet(send); + + // Creates the points to set for the tuple and its element at {1}. + auto top_buffer = points_to_set.mutable_element(ShapeIndex({})); + top_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({}))); + points_to_set.add_tuple_source({}, send); + + auto context_buffer = points_to_set.mutable_element(ShapeIndex({1})); + context_buffer->push_back( + &logical_buffer_analysis_->GetBuffer(send, ShapeIndex({1}))); + + // Recursively copy the points to set of the operand to output tuple {0}. + const PointsToSet& operand_points_to_set = GetPointsToSet(send->operand(0)); + operand_points_to_set.ForEachElement( + [&points_to_set, &operand_points_to_set]( + const ShapeIndex& src_index, + const PointsToSet::BufferList& points_to) { + ShapeIndex target_index({0}); + for (auto element : src_index) { + target_index.push_back(element); + } + *points_to_set.mutable_element(target_index) = points_to; + + for (HloInstruction* tuple : + operand_points_to_set.tuple_sources(src_index)) { + points_to_set.add_tuple_source(target_index, tuple); + } + }); + + return Status::OK(); +} + Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { tensorflow::gtl::ArraySlice operands(tuple->operands()); PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index e6157a1ed11b5df24458fe820a4e0e329eb86ae4..8928de107eed8c40bbe2130e26fe83ca3802d2f6 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -251,6 +251,8 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleSend(HloInstruction* send) override; Status HandleSelect(HloInstruction* select) override; string ToString() const; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 694ed57fa24d59bd0a28c7bb9b67af8165e90363..dec446d4dac650ba43992f7870764eedc80cb2cf 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -313,6 +313,51 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) { {constant1, constant2, copy}); } +TEST_F(TuplePointsToAnalysisTest, SendAndSendDone) { + // Send forwards its operand to the output tuple at {0}. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send).IsDistinct()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(send_done).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(send_done).IsDistinct()); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send).element({}), {send}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send).element({0}), {constant}); + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(send_done).CreateFlattenedSet(), + {send_done}); + ExpectHasBufferAliases(constant, {}, {{constant, {}}, {send, {0}}}); +} + +TEST_F(TuplePointsToAnalysisTest, RecvAndRecvDone) { + // RecvDone forwards its operand tuple element at {0} to the output. + auto builder = HloComputation::Builder(TestName()); + auto recv = builder.AddInstruction(HloInstruction::CreateRecv( + ShapeUtil::MakeShape(F32, {1, 2, 3}), /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + + BuildModuleAndRunAnalysis(builder.Build()); + + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv).IsDistinct()); + EXPECT_FALSE(points_to_analysis_->GetPointsToSet(recv_done).IsAmbiguous()); + EXPECT_TRUE(points_to_analysis_->GetPointsToSet(recv_done).IsDistinct()); + + ExpectHasTopLevelBuffers( + points_to_analysis_->GetPointsToSet(recv).element({}), {recv}); + ExpectHasBufferAliases(recv, {0}, {{recv, {0}}, {recv_done, {}}}); +} + TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // Select from two different tuples. This should create an ambiguous points to // set containing the union of both sides. diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 006c814996df9b209e6cd4d75bc04689c4e297c5..e6893c8133b17cac3ca381df58d417eef15b60c4 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -88,8 +88,6 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kAtan2; case BINOP_COMPLEX: return HloOpcode::kComplex; - case BINOP_DOT: - return HloOpcode::kDot; case BINOP_MUL: return HloOpcode::kMultiply; case BINOP_ADD: @@ -765,6 +763,54 @@ StatusOr UserComputation::AddWhileInstruction( return handle; } +StatusOr UserComputation::AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* pred, + LookUpRequest(conditional_request.predicate())); + TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, + LookUpRequest(conditional_request.true_operand())); + TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, + LookUpRequest(conditional_request.false_operand())); + + VersionedComputationHandle::Version true_computation_version = + true_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr true_computation_shape, + true_computation.ComputeProgramShape(true_computation_version)); + + VersionedComputationHandle::Version false_computation_version = + false_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr false_computation_shape, + false_computation.ComputeProgramShape(false_computation_version)); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferConditionalShape( + pred->output_shape(), true_operand->output_shape(), + false_operand->output_shape(), + *true_computation_shape, *false_computation_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(true_computation_version); + request.add_embedded_computation_versions(false_computation_version); + *request.mutable_request()->mutable_conditional_request() = + conditional_request; + + VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << conditional_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddBroadcastInstruction( const BroadcastRequest& broadcast_request) { tensorflow::mutex_lock lock(mutex_); @@ -994,6 +1040,32 @@ StatusOr UserComputation::AddConvertInstruction( return handle; } +StatusOr UserComputation::AddBitcastConvertInstruction( + const ConvertRequest& convert_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(convert_request.operand())); + + TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape( + operand->output_shape(), + convert_request.new_element_type())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_bitcast_convert_request() = + convert_request; + + VLOG(1) << "AddBitcastConvertInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << convert_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddReducePrecisionInstruction( const ReducePrecisionRequest& reduce_precision_request) { tensorflow::mutex_lock lock(mutex_); @@ -1056,7 +1128,7 @@ StatusOr UserComputation::AddCrossReplicaSumInstruction( TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(cross_replica_sum_request.operand())); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - operand->output_shape())); + {&operand->output_shape()})); ComputationDataHandle handle = CreateComputationDataHandle(); @@ -1181,6 +1253,33 @@ StatusOr UserComputation::AddCustomCallInstruction( return handle; } +StatusOr UserComputation::AddDotInstruction( + const DotRequest& dot_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookUpRequest(dot_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookUpRequest(dot_request.rhs())); + + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( + lhs->output_shape(), rhs->output_shape(), + dot_request.dimension_numbers())); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_dot_request() = dot_request; + + VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << dot_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddUnaryInstruction( const UnaryOpRequest& unary_request) { tensorflow::mutex_lock lock(mutex_); @@ -1482,14 +1581,15 @@ UserComputation::ComputeProgramShape( namespace { -// A visitor which checks whether an operation is a compile-time constant. That -// is, the operation does not depend on any parameter instructions. The visitor -// walks the computation starting at a given operation and sets is_constant to -// false iff a parameter or RNG operation is encountered. -void ConstantVisitor(const SessionComputation& session_computation, - const ComputationDataHandle& handle, - std::set* visited, bool* is_constant) { - if (visited->count(handle.handle()) != 0 || !*is_constant) { +// A visitor which checks whether an operation is pure functional meaning that +// it doesn't depend on any parameter with an index higher then num_parameters. +// The visitor walks the computation starting at a given operation and sets +// is_functional to false iff a parameter or RNG operation is encountered. +void PureFunctionalVisitor(const SessionComputation& session_computation, + const ComputationDataHandle& handle, + int64 num_parameters, std::set* visited, + bool* is_functional) { + if (visited->count(handle.handle()) != 0 || !*is_functional) { return; } @@ -1497,7 +1597,7 @@ void ConstantVisitor(const SessionComputation& session_computation, session_computation.requests().at(handle.handle()); switch (request.request().op_case()) { case OpRequest::kRngRequest: - *is_constant = false; + *is_functional = false; break; case OpRequest::kConstantRequest: @@ -1506,41 +1606,43 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kGetTupleElementRequest: { const GetTupleElementRequest& get_tuple_element_request = request.request().get_tuple_element_request(); - ConstantVisitor(session_computation, get_tuple_element_request.operand(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + get_tuple_element_request.operand(), num_parameters, + visited, is_functional); break; } case OpRequest::kSliceRequest: { const SliceRequest& slice_request = request.request().slice_request(); - ConstantVisitor(session_computation, slice_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, slice_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kDynamicSliceRequest: { const DynamicSliceRequest& dynamic_slice_request = request.request().dynamic_slice_request(); - ConstantVisitor(session_computation, dynamic_slice_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, - dynamic_slice_request.start_indices(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + dynamic_slice_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_slice_request.start_indices(), + num_parameters, visited, is_functional); break; } case OpRequest::kDynamicUpdateSliceRequest: { const DynamicUpdateSliceRequest& dynamic_update_slice_request = request.request().dynamic_update_slice_request(); - ConstantVisitor(session_computation, - dynamic_update_slice_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, - dynamic_update_slice_request.update(), visited, - is_constant); - ConstantVisitor(session_computation, - dynamic_update_slice_request.start_indices(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.update(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.start_indices(), + num_parameters, visited, is_functional); break; } @@ -1549,7 +1651,8 @@ void ConstantVisitor(const SessionComputation& session_computation, request.request().concatenate_request(); for (const ComputationDataHandle& handle : concatenate_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } break; } @@ -1557,61 +1660,72 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kConvolveRequest: { const ConvolveRequest& convolve_request = request.request().convolve_request(); - ConstantVisitor(session_computation, convolve_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, convolve_request.rhs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, convolve_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, convolve_request.rhs(), + num_parameters, visited, is_functional); break; } case OpRequest::kCrossReplicaSumRequest: { // TODO(b/33009255): Implmement constant folding for cross replica sum. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kInfeedRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kOutfeedRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kCallRequest: { const CallRequest& call_request = request.request().call_request(); for (const ComputationDataHandle& handle : call_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } // TODO(b/32495713): We aren't checking the to_apply computation itself, // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_constant=false in other similar + // cannot be constant. We cannot set is_functional=false in other similar // cases since we're already relying on IsConstant to return true. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kCustomCallRequest: { - *is_constant = false; + *is_functional = false; + break; + } + + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + PureFunctionalVisitor(session_computation, dot_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, dot_request.rhs(), + num_parameters, visited, is_functional); break; } case OpRequest::kSendRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kRecvRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kMapRequest: { const MapRequest& map_request = request.request().map_request(); for (const ComputationDataHandle& handle : map_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } // TODO(b/32495713): We aren't checking the to_apply computation itself. break; @@ -1619,10 +1733,10 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kReduceRequest: { const ReduceRequest& reduce_request = request.request().reduce_request(); - ConstantVisitor(session_computation, reduce_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, reduce_request.init_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reduce_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, reduce_request.init_value(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the to_apply computation itself. break; } @@ -1630,10 +1744,12 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kReduceWindowRequest: { const ReduceWindowRequest& reduce_window_request = request.request().reduce_window_request(); - ConstantVisitor(session_computation, reduce_window_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, reduce_window_request.init_value(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + reduce_window_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + reduce_window_request.init_value(), num_parameters, + visited, is_functional); // TODO(b/32495713): We aren't checking the to_apply computation itself. break; } @@ -1641,13 +1757,15 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kSelectAndScatterRequest: { const SelectAndScatterRequest& select_and_scatter_request = request.request().select_and_scatter_request(); - ConstantVisitor(session_computation, select_and_scatter_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, select_and_scatter_request.source(), - visited, is_constant); - ConstantVisitor(session_computation, - select_and_scatter_request.init_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.source(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.init_value(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the select and scatter // computations themselves. break; @@ -1656,76 +1774,105 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kBroadcastRequest: { const BroadcastRequest& broadcast_request = request.request().broadcast_request(); - ConstantVisitor(session_computation, broadcast_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, broadcast_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kReshapeRequest: { const ReshapeRequest& reshape_request = request.request().reshape_request(); - ConstantVisitor(session_computation, reshape_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reshape_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kReverseRequest: { const ReverseRequest& reverse_request = request.request().reverse_request(); - ConstantVisitor(session_computation, reverse_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reverse_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kPadRequest: { const PadRequest& pad_request = request.request().pad_request(); - ConstantVisitor(session_computation, pad_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, pad_request.padding_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, pad_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, pad_request.padding_value(), + num_parameters, visited, is_functional); break; } case OpRequest::kParameterRequest: { - *is_constant = false; + const ParameterRequest& parameter_request = + request.request().parameter_request(); + if (parameter_request.parameter() >= num_parameters) { + *is_functional = false; + } break; } case OpRequest::kConvertRequest: { const ConvertRequest& convert_request = request.request().convert_request(); - ConstantVisitor(session_computation, convert_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, convert_request.operand(), + num_parameters, visited, is_functional); + break; + } + + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + PureFunctionalVisitor(session_computation, convert_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); - ConstantVisitor(session_computation, while_request.init(), visited, - is_constant); + PureFunctionalVisitor(session_computation, while_request.init(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the condition and body // computations themselves. - *is_constant = false; + *is_functional = false; + break; + } + + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + PureFunctionalVisitor(session_computation, + conditional_request.predicate(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.true_operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.false_operand(), num_parameters, + visited, is_functional); + // TODO(b/32495713): We aren't checking the true and false computations + // themselves. break; } case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); - ConstantVisitor(session_computation, ternary_op_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, ternary_op_request.rhs(), visited, - is_constant); - ConstantVisitor(session_computation, ternary_op_request.ehs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), + num_parameters, visited, is_functional); break; } case OpRequest::kTransposeRequest: { const TransposeRequest& transpose_request = request.request().transpose_request(); - ConstantVisitor(session_computation, transpose_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, transpose_request.operand(), + num_parameters, visited, is_functional); break; } @@ -1734,7 +1881,8 @@ void ConstantVisitor(const SessionComputation& session_computation, request.request().variadic_op_request(); for (const ComputationDataHandle& handle : variadic_op_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } break; } @@ -1742,67 +1890,74 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); - ConstantVisitor(session_computation, unary_op_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, unary_op_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormTrainingRequest: { const BatchNormTrainingRequest& batch_norm_training_request = request.request().batch_norm_training_request(); - ConstantVisitor(session_computation, - batch_norm_training_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_training_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_training_request.offset(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.scale(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.offset(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormInferenceRequest: { const BatchNormInferenceRequest& batch_norm_inference_request = request.request().batch_norm_inference_request(); - ConstantVisitor(session_computation, - batch_norm_inference_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_inference_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_inference_request.offset(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_inference_request.mean(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_inference_request.variance(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.scale(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.offset(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.mean(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.variance(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormGradRequest: { const BatchNormGradRequest& batch_norm_grad_request = request.request().batch_norm_grad_request(); - ConstantVisitor(session_computation, batch_norm_grad_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.mean(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.variance(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_grad_request.grad_output(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.scale(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.variance(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.grad_output(), + num_parameters, visited, is_functional); break; } case OpRequest::kBinaryOpRequest: { const BinaryOpRequest& binary_op_request = request.request().binary_op_request(); - ConstantVisitor(session_computation, binary_op_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, binary_op_request.rhs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, binary_op_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, binary_op_request.rhs(), + num_parameters, visited, is_functional); break; } @@ -1817,8 +1972,8 @@ void ConstantVisitor(const SessionComputation& session_computation, } // namespace -StatusOr UserComputation::IsConstant( - const ComputationDataHandle& handle) { +StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, + int64 num_parameters) { tensorflow::mutex_lock lock(mutex_); // Verify that the handle is valid. @@ -1829,7 +1984,8 @@ StatusOr UserComputation::IsConstant( bool is_constant = true; std::set visited; - ConstantVisitor(session_computation_, handle, &visited, &is_constant); + PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, + &is_constant); return is_constant; } @@ -1928,6 +2084,21 @@ UserComputation::GetEmbeddedComputations( break; } + case OpRequest::kConditionalRequest: { + CHECK_EQ(2, request.embedded_computation_versions_size()); + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + const VersionedComputationHandle true_computation_versioned_handle = { + conditional_request.true_computation(), + request.embedded_computation_versions(0)}; + computations.push_back(true_computation_versioned_handle); + const VersionedComputationHandle false_computation_versioned_handle = + {conditional_request.false_computation(), + request.embedded_computation_versions(1)}; + computations.push_back(false_computation_versioned_handle); + break; + } + default: // No embedded computation. break; @@ -2014,6 +2185,16 @@ Status UserComputation::RemapEmbeddedComputations( TF_RETURN_IF_ERROR(update(while_request->mutable_body())); break; } + case OpRequest::kConditionalRequest: { + TF_RET_CHECK(2 == request.embedded_computation_versions_size()); + ConditionalRequest* conditional_request = + request.mutable_request()->mutable_conditional_request(); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_true_computation())); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_false_computation())); + break; + } default: // No embedded computation. TF_RET_CHECK(0 == request.embedded_computation_versions_size()); @@ -2347,12 +2528,28 @@ static void ForEachOperand( break; } + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + apply(convert_request.operand()); + break; + } + case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); apply(while_request.init()); break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + apply(conditional_request.predicate()); + apply(conditional_request.true_operand()); + apply(conditional_request.false_operand()); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -2389,6 +2586,13 @@ static void ForEachOperand( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + apply(dot_request.rhs()); + apply(dot_request.lhs()); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -2515,6 +2719,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( if (ShapeUtil::IsScalar(operand->shape())) { HloInstruction* broadcast = hlo_builder_.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, operand, {})); + broadcast->set_metadata(operand->metadata()); if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } @@ -2535,6 +2740,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( ShapeUtil::MakeShape(operand->shape().element_type(), reshaped_dimensions), operand)); + reshaped_operand->set_metadata(operand->metadata()); if (operand->has_sharding()) { reshaped_operand->set_sharding(operand->sharding()); } @@ -2542,6 +2748,7 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( HloInstruction* broadcast = hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( broadcast_shape, reshaped_operand, broadcast_dimensions)); + broadcast->set_metadata(operand->metadata()); if (operand->has_sharding()) { broadcast->set_sharding(operand->sharding()); } @@ -2665,13 +2872,22 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + HloInstruction* lhs = lookup_instruction(dot_request.lhs()); + HloInstruction* rhs = lookup_instruction(dot_request.rhs()); + hlo_instruction = add_instruction(HloInstruction::CreateDot( + request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); HloInstruction* operand = lookup_instruction(cross_replica_sum_request.operand()); hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), operand)); + request.output_shape(), {operand})); break; } @@ -2904,8 +3120,9 @@ void ComputationLowerer::Visit( case OpRequest::kRecvRequest: { const RecvRequest& recv_request = request.request().recv_request(); - hlo_instruction = add_instruction(HloInstruction::CreateRecv( + HloInstruction* recv = add_instruction(HloInstruction::CreateRecv( request.output_shape(), recv_request.channel_handle().handle())); + hlo_instruction = add_instruction(HloInstruction::CreateRecvDone(recv)); break; } @@ -2927,6 +3144,15 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kBitcastConvertRequest: { + const ConvertRequest& convert_request = + request.request().bitcast_convert_request(); + HloInstruction* operand = lookup_instruction(convert_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateBitcastConvert( + request.output_shape(), operand)); + break; + } + case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); CHECK_EQ(2, request.embedded_computation_versions_size()); @@ -2944,6 +3170,30 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + CHECK_EQ(2, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version true_computation_version = + request.embedded_computation_versions(0); + HloComputation* true_computation = ResolveComputation( + conditional_request.true_computation(), true_computation_version); + VersionedComputationHandle::Version false_computation_version = + request.embedded_computation_versions(1); + HloComputation* false_computation = ResolveComputation( + conditional_request.false_computation(), false_computation_version); + HloInstruction* predicate = + lookup_instruction(conditional_request.predicate()); + HloInstruction* true_operand = + lookup_instruction(conditional_request.true_operand()); + HloInstruction* false_operand = + lookup_instruction(conditional_request.false_operand()); + hlo_instruction = add_instruction(HloInstruction::CreateConditional( + request.output_shape(), predicate, true_operand, true_computation, + false_operand, false_computation)); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -2951,6 +3201,25 @@ void ComputationLowerer::Visit( HloInstruction* rhs = lookup_instruction(ternary_op_request.rhs()); HloInstruction* ehs = lookup_instruction(ternary_op_request.ehs()); auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop()); + + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { + if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { + // lhs side is being implicitly broadcast. Change to explicit. + lhs = + ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape()); + } + + if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) { + rhs = + ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape()); + } + + if (!ShapeUtil::SameDimensions(request.output_shape(), ehs->shape())) { + ehs = + ImplicitBroadcastToExplicitBroadcast(ehs, request.output_shape()); + } + } + hlo_instruction = add_instruction(HloInstruction::CreateTernary( request.output_shape(), hlo_opcode, lhs, rhs, ehs)); break; @@ -3055,8 +3324,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - binary_op_request.binop() != BINOP_DOT) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. lhs = @@ -3097,8 +3365,9 @@ void ComputationLowerer::Visit( case OpRequest::kSendRequest: { const SendRequest& send_request = request.request().send_request(); HloInstruction* operand = lookup_instruction(send_request.operand()); - hlo_instruction = add_instruction(HloInstruction::CreateSend( + HloInstruction* send = add_instruction(HloInstruction::CreateSend( operand, send_request.channel_handle().handle())); + hlo_instruction = add_instruction(HloInstruction::CreateSendDone(send)); break; } @@ -3109,7 +3378,7 @@ void ComputationLowerer::Visit( LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); } (*instructions)[handle.handle()] = hlo_instruction; -} +} // NOLINT(readability/fn_size) } // namespace diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index dabf68e298ed2600d5248b7b8c7b1e014efedb14..8a78d520e19024f5e397d6e0c2f4e0523264e176 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -70,7 +70,7 @@ class UserComputation { // Enqueues a pad instruction onto this user computation. StatusOr AddPadInstruction( - const PadRequest& parameter_request); + const PadRequest& pad_request); // Enqueues a tracing instruction onto this user computation. // Returns an error status if the operand cannot be resolved. @@ -105,7 +105,7 @@ class UserComputation { // Enqueues a ternary instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. StatusOr AddTernaryInstruction( - const TernaryOpRequest& request); + const TernaryOpRequest& ternary_request); // Enqueues a variadic instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. @@ -153,6 +153,10 @@ class UserComputation { StatusOr AddCustomCallInstruction( const CustomCallRequest& custom_call_request); + // Enqueues a dot instruction onto this user computation. + StatusOr AddDotInstruction( + const DotRequest& dot_request); + // Enqueues a broadcast instruction onto this user computation. StatusOr AddBroadcastInstruction( const BroadcastRequest& broadcast_request); @@ -179,26 +183,30 @@ class UserComputation { // Enqueues a concatenate instruction onto this user computation. StatusOr AddConcatenateInstruction( - const ConcatenateRequest& slice_request); + const ConcatenateRequest& concatenate_request); // Enqueues a convert instruction onto this user computation. StatusOr AddConvertInstruction( const ConvertRequest& convert_request); + // Enqueues a bitcast element instruction onto this user computation. + StatusOr AddBitcastConvertInstruction( + const ConvertRequest& convert_request); + // Enqueues a reduce instruction onto this user computation. StatusOr AddReduceInstruction( const ReduceRequest& reduce_request, - const UserComputation& reduction_computation); + const UserComputation& to_apply_computation); // Enqueues a windowed reduce instruction onto this user computation. StatusOr AddReduceWindowInstruction( const ReduceWindowRequest& reduce_window_request, - const UserComputation& reduction_computation); + const UserComputation& to_apply_computation); // Enqueues a select-and-scatter instruction onto this user // computation. StatusOr AddSelectAndScatterInstruction( - const SelectAndScatterRequest& scatter_to_selected_window_element_request, + const SelectAndScatterRequest& select_and_scatter_request, const UserComputation& select_computation, const UserComputation& scatter_computation); @@ -212,6 +220,12 @@ class UserComputation { const UserComputation& condition_computation, const UserComputation& body_computation); + // Enqueues a conditional instruction on this user computation. + StatusOr AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation); + // Enqueues a Send instruction onto this user computation. Status AddSendInstruction(const SendRequest& send_request); @@ -250,9 +264,11 @@ class UserComputation { StatusOr> ComputeProgramShape( VersionedComputationHandle::Version version) const; - // Returns true if the given data handle does not depend on any - // parameters. That is, the value can be computed at compile time. - StatusOr IsConstant(const ComputationDataHandle& handle); + // Returns true if the given data handle does not depend on any parameter with + // index higher then num_parameters. That is, the value can be computed at + // compile time if we know the first num_parameters arguments. + StatusOr IsConstant(const ComputationDataHandle& handle, + int64 num_parameters); // Returns the output shape of the operation indicated by the given handle. StatusOr GetShape(const ComputationDataHandle& handle); diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 5afaf226ae0cce7e9afc966c6b4adf838aeebc91..e45673300b6c5f85be4153f2db821d8abbced7cd 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -334,50 +334,5 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } -TEST_F(UserComputationTest, SkipDotInEliminatingImplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // %a = Param({1, 3}); - // %b = Param({3, 1}); - // %dot = Dot(%a, %b); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {3, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest dot; - dot.set_binop(BINOP_DOT); - *dot.mutable_lhs() = a_handle; - *dot.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(dot).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - EXPECT_EQ(3, hlo_computation->instruction_count()); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2fd64a4d9f3dc343b2e44b5efa31aacc6085042 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -0,0 +1,644 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +// Finds and returns the non-constant operand in instr. +// +// CHECK-fails if instr doesn't have exactly one unique non-constant operand. +static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { + const HloInstruction* result = nullptr; + for (const HloInstruction* operand : instr->operands()) { + if (!operand->IsConstant()) { + if (result != nullptr) { + CHECK_EQ(result, operand); + } + result = operand; + } + } + CHECK_NE(result, nullptr); + return result; +} + +// Determines whether the given instruction is a send/recv node, or has a +// subcomputation which contains a send/recv node. +static bool IsOrContainsSendOrRecv(const HloInstruction* instr); + +// Determines whether the given computation contains a send or recv node. +static bool ContainsSendOrRecv(const HloComputation* comp) { + for (const auto* instr : comp->instructions()) { + if (IsOrContainsSendOrRecv(instr)) { + return true; + } + } + return false; +} + +static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kSend || + instr->opcode() == HloOpcode::kSendDone || + instr->opcode() == HloOpcode::kRecv || + instr->opcode() == HloOpcode::kRecvDone) { + return true; + } + for (const auto& subcomp : instr->called_computations()) { + if (ContainsSendOrRecv(subcomp)) { + return true; + } + } + return false; +} + +// If all of instr's operands are either constants or have the form +// get-tuple-element(gte_operand, N) +// for the same value N, returns N. Otherwise, returns nullopt. +static optional GetGTEOperandIndex(const HloInstruction* instr, + const HloInstruction* gte_operand) { + VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " + << gte_operand->ToString() << ")"; + optional tuple_idx; + for (const HloInstruction* operand : instr->operands()) { + if (operand->IsConstant()) { + continue; + } + if (operand->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "instr uses something other than gte(gte_operand): " + << operand->ToString(); + return nullopt; + } + if (operand->operand(0) != gte_operand) { + VLOG(2) << "instr has gte whose operand is not gte_operand: " + << operand->ToString(); + return nullopt; + } + if (tuple_idx && tuple_idx != operand->tuple_index()) { + VLOG(2) << "instr has operands with conflicting gte indices, " + << *tuple_idx << " vs " << operand->tuple_index(); + return nullopt; + } + + tuple_idx = operand->tuple_index(); + } + return tuple_idx; +} + +// Tries to get the tuple index of the induction variable of a while loop. +// +// Checks that the loop condition and root both plumb the induction variable +// through the same tuple index, and that they both apply exactly one op to the +// induction variable before deciding whether to do another loop iteration (in +// the loop condition's case) or packing the induction variable into the result +// tuple (in the loop body's case). +// +// Specifically, checks that the loop condition has structure +// +// root = op(constants, get-tuple-elem(param0, N), constants) +// +// and the loop body has the structure +// +// inc = op(constants, get-tuple-elem(param0, N), constants) +// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). +// +// If so, returns N. Otherwise, returns nullopt. +static optional GetLoopInductionVarTupleIdx( + const HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Finding induction variable for loop " + << while_op->ToShortString(); + + // The while_cond computation should have the form + // + // while_cond_root = + // op(constants, get-tuple-elem(while_cond_param, N), constants). + // + // If it does, set indvar_tuple_idx to N. + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_param = while_cond->parameter_instruction(0); + optional indvar_tuple_idx = + GetGTEOperandIndex(while_cond_root, while_cond_param); + if (!indvar_tuple_idx) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // The while_body computation should have the form + // + // while_body_inc = + // op(constants, get-tuple-elem(while_body_param, N), constants) + // while_body_root = tuple(..., while_body_inc, ...) + // + // where while_body_inc is operand N of while_body_root. + auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); + auto* while_body_param = while_body->parameter_instruction(0); + optional while_body_indvar_tuple_idx = + GetGTEOperandIndex(while_body_inc, while_body_param); + if (!while_body_indvar_tuple_idx) { + VLOG(2) + << "Induction variable not found in while body increment instruction: " + << while_body_inc->ToString(); + return nullopt; + } + if (while_body_indvar_tuple_idx != indvar_tuple_idx) { + VLOG(2) << "Tuple index of induction variable does not match between loop " + "condition (" + << *indvar_tuple_idx << ") and while body (" + << *while_body_indvar_tuple_idx << ")"; + return nullopt; + } + + // Finally, check that the while loop's initial value is a tuple with enough + // elements. + auto* while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); + return nullopt; + } + + VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; + return indvar_tuple_idx; +} + +// Tries to determine the number of times the given loop executes. Currently +// simply returns 0, 1, or "can't tell" (nullopt). +static optional GetLoopTripCount(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Getting trip count for loop " << while_op->ToString(); + + // The loop's induction variable is found at + // + // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), + // + // where comp is while_op->while_body() or while_op->while_condition(). + optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx) { + return nullopt; + } + + VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx + << " in input tuple."; + + // Now that we know the index of the induction variable, we can we can try to + // compute how many times the loop executes. Start by computing the induction + // variable's initial value. + HloEvaluator evaluator; + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); + StatusOr> indvar_init_result = + evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init: " + << indvar_init_result.status(); + return nullopt; + } + + // Evaluates the while loop's condition, returning either "true" (continue + // looping), "false" (stop looping), or nullopt (can't evaluate). + auto evaluate_while_cond = [&](const Literal& indvar) -> optional { + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + StatusOr> result = + evaluator.EvaluateWithSubstitutions(while_cond_root, + {{while_cond_indvar, &indvar}}); + if (!result.ok()) { + VLOG(2) << "Couldn't evaluate while cond: " << result.status(); + return nullopt; + } + return result.ValueOrDie()->GetArraySlice() == + tensorflow::gtl::ArraySlice{true}; + }; + + // The initial value of the induction variable. + const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); + + // Evaluate whether the while condition is true when seeded with + // indvar_iter0_val. + optional while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); + if (while_cond_iter0_val == false) { + VLOG(2) << "Loop has static trip count of 0."; + return 0; + } + + // Calculate the value of the induction variable after one iteration of the + // loop, and check whether the while condition is true with this new value. + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(*indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + StatusOr> indvar_iter1_result = + evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); + if (!indvar_iter1_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable update: " + << indvar_iter1_result.status(); + return nullopt; + } + const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); + optional while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); + if (while_cond_iter1_val == false) { + VLOG(2) << "Determined that loop has static trip count of 1."; + return 1; + } + + VLOG(2) << "Loop has unknown trip count >= 1."; + return nullopt; +} + +// Tries to remove elements in a while loop's tuple that aren't used within the +// loop. +// +// Specifically, if a loop is tuple-shaped, and there exists some element of +// that tuple that is not used by the loop condition and is not used by the loop +// body except to pass it to the next iteration of the loop, then we can remove +// that element from the loop's tuples. +static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!ShapeUtil::IsTuple(while_init->shape())) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + // Bail if param0 of while_cond or while_body has users which aren't of type + // get-tuple-element. + for (const HloInstruction* instr : {while_body->parameter_instruction(0), + while_cond->parameter_instruction(0)}) { + for (const HloInstruction* user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "Cowardly refusing to analyze while loop with " + << instr->ToStringNoMetadata() + << " used by non-GTE instruction " << user->ToStringNoMetadata() + << " in computation " << instr->parent()->name(); + return false; + } + } + } + + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); + if (tuple_size == 0) { + VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " + "empty."; + return false; + } + + tensorflow::gtl::FlatSet used_tuple_indices; + for (HloComputation* comp : {while_body, while_cond}) { + // The HLO verifier ensures that while_input's shape matches while_init's + // shape, which we verified above is a tuple. + HloInstruction* while_input = comp->parameter_instruction(0); + + for (const HloInstruction* user : while_input->users()) { + // This user doesn't count if it's only used by the while body's root, and + // the root places the tuple element into the same index of the tuple as + // it came from. That just amounts to us carrying the variable through + // the loop. + // + // Careful: HloInstruction::operand_index returns the first index the + // operand appears in, but it may appear more than once! + if (user->user_count() == 1 && user->users().front() == while_body_root && + while_body_root->operand_index(user) == user->tuple_index() && + std::count(while_body_root->operands().begin(), + while_body_root->operands().end(), user) == 1) { + continue; + } + + used_tuple_indices.insert(user->tuple_index()); + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If a tuple element is not passed unmodified from the while body's param0 + // through to the while body's root, count that element as "used", since + // removing that element would be observable. + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + if (used_tuple_indices.count(i)) { + continue; + } + + auto* operand = while_body_root->operand(i); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->operand(0) != while_body->parameter_instruction(0) || + operand->tuple_index() != i) { + VLOG(2) << "Tuple index " << i + << " is not passed through loop body unmodified."; + used_tuple_indices.insert(i); + + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If we got here, used_tuple_indices.size() < tuple_size, meaning some + // elements of the loop's tuple aren't used by while_body or while_cond. + CHECK_LT(used_tuple_indices.size(), tuple_size); + + VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() + << " elements from tuple of " << while_op->ToStringNoMetadata(); + + // Build up maps from the old/new to the new/old tuple indices. + std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), + used_tuple_indices.end()); + std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); + + tensorflow::gtl::FlatMap old_to_new_tuple_idx; + for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { + int64 old_idx = new_to_old_tuple_idx[new_idx]; + old_to_new_tuple_idx[old_idx] = new_idx; + VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx; + } + + // Compute the shape of the while op after we remove the dead indices. + std::vector new_while_tuple_elem_shapes; + new_while_tuple_elem_shapes.reserve(new_to_old_tuple_idx.size()); + for (int64 old_idx : new_to_old_tuple_idx) { + new_while_tuple_elem_shapes.push_back( + while_init->shape().tuple_shapes(old_idx)); + } + Shape new_while_shape = + ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes); + + // Returns a map from elements in the computation to new instructions which + // replace the old instructions after we remove unused elements from the while + // tuple. + auto make_while_computation_replacements = [&](const HloComputation* comp) { + std::unordered_map> + replacements; + + auto* param = comp->parameter_instruction(0); + replacements.emplace(param, HloInstruction::CreateParameter( + 0, new_while_shape, param->name())); + + // Materialize param's users, since we're about to add new ones below. + std::vector materialized_users(param->users().begin(), + param->users().end()); + for (const auto* user : materialized_users) { + // The while body root is handled separately. + if (user == while_body_root) { + continue; + } + CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement) + << user->ToStringNoMetadata(); + + int64 old_idx = user->tuple_index(); + auto new_idx_iter = old_to_new_tuple_idx.find(old_idx); + if (new_idx_iter != old_to_new_tuple_idx.end()) { + // This is a GTE of an index that survives. Replace it. + replacements.emplace( + user, HloInstruction::CreateGetTupleElement(user->shape(), param, + new_idx_iter->second)); + } else { + // This is a GTE of an index that we've removed. Remove it from the + // cloned computation. + CHECK(user->user_count() == 0 || + user->user_count() == 1 && + user->users().front() == while_body_root) + << "Instruction " << user->ToStringNoMetadata() + << " should be unused (except by root of while body), but has " + "users: {" + << tensorflow::str_util::Join( + user->users(), ", ", + [](string* out, const HloInstruction* instr) { + tensorflow::strings::StrAppend( + out, instr->ToStringNoMetadata()); + }) + << "}"; + + replacements.emplace(user, nullptr); + } + } + return replacements; + }; + + // Create the new while condition, body, and init value. + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacements( + make_while_computation_replacements(while_cond)); + + std::unordered_map> + while_body_replacements = make_while_computation_replacements(while_body); + std::vector new_while_body_root_elems; + new_while_body_root_elems.reserve(new_to_old_tuple_idx.size()); + for (int64 old_idx : new_to_old_tuple_idx) { + new_while_body_root_elems.push_back( + while_body_root->mutable_operand(old_idx)); + } + while_body_replacements.emplace( + while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); + std::unique_ptr new_while_body = + while_body->CloneWithReplacements(std::move(while_body_replacements)); + + // Add a new while_init instruction that repackages the old while_init + // instruction's elements. We rely on the AlgebraicSimplifier and DCE to + // clean this up in the common case where while_init is a tuple op. (It's + // definitely tuple-shaped, but it's not necessarily a tuple op.) + std::vector new_while_init_elems; + new_while_init_elems.reserve(new_to_old_tuple_idx.size()); + for (int64 old_idx : new_to_old_tuple_idx) { + new_while_init_elems.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + while_init->shape().tuple_shapes(old_idx), while_init, old_idx))); + } + auto* new_while_init = computation->AddInstruction( + HloInstruction::CreateTuple(new_while_init_elems)); + + // Create the new while op. + auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile( + new_while_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + new_while_init)); + + // Create a tuple op that recreates the output of the old while op. That is, + // we transform to + // + // new_while_init while_init + // | | + // V | + // new_while | + // | | + // -------| |---- + // V V + // new_tuple + // | + // V + // (orig. users of while op) + // + // The tuple simplifier will then simplify this if possible, removing + // new_tuple and while_init. + std::vector new_tuple_elems; + for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) { + auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx); + if (new_tuple_idx_it != old_to_new_tuple_idx.end()) { + int64 gte_idx = new_tuple_idx_it->second; + new_tuple_elems.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_while_op->shape().tuple_shapes(gte_idx), new_while_op, + gte_idx))); + } else { + new_tuple_elems.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + while_init->shape().tuple_shapes(old_idx), while_init, old_idx))); + } + } + HloInstruction* new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); + TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple)); + + return true; +} + +// Tries to remove a while loop from the graph. +// +// - Loops with trip count of 0 can be replaced by the loop's "init" value. +// - Loops with trip count of 1 can be replaced by the loop's body, with the +// loop itself removed. +// +// Returns true if it made a change to the graph. +static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { + // Cowardly refuse to remove loops that are not removable. In practice, + // this means that we can't remove loops that contain side-effecting + // instructions or have control predecessors/successors. + // + // This is not a fundamental limitation. The control operands can be moved + // onto the new HLOs after simplification, and any side-effecting ops inside + // the loop aren't removed, just cloned and added back to the loop. + // Nevertheless our infrastructure sees loop simplification as removal of + // these nodes and currently doesn't allow it. + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { + VLOG(2) << "Not attempting to remove while loop it is not removable: " + << while_op->ToShortString(); + return false; + } + + // Remove while loops with static trip count of 0. + optional trip_count = GetLoopTripCount(while_op); + if (trip_count && *trip_count == 0) { + // The loop never executes, so the value of the loop is the value of its + // "init" operand. + auto computation = while_op->parent(); + + // Remove while_op (i.e., call ReplaceInstruction rather than + // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in + // a loop without an intervening DCE, we don't try to re-remove the loop. + TF_RETURN_IF_ERROR(computation->ReplaceInstruction( + while_op, while_op->mutable_operand(0))); + return true; + } + + // Transform while loops with static trip count of 1 into a call op, then + // inline the call. + if (trip_count && *trip_count == 1) { + auto computation = while_op->parent(); + auto call_op = computation->AddInstruction(HloInstruction::CreateCall( + while_op->shape(), while_op->operands(), while_op->while_body())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_op)); + return true; + } + return false; +} + +StatusOr WhileLoopSimplifier::Run(HloModule* module) { + XLA_VLOG_LINES(3, + "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); + bool changed = false; + + // Gather all the while ops in our module. We do this ahead of time so we + // don't have to worry about mutating the lists of computations or + // instructions while we iterate. + std::vector while_ops; + for (auto* comp : module->computations()) { + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + while_ops.push_back(instr); + } + } + } + + for (HloInstruction* while_op : while_ops) { + // We can't remove while loops that contain send/recv nodes, because we rely + // on the particular loop structure around the node matching on the send and + // recv sides. Removing dead while params requires us to remove the loop + // and replace it with a new one, so we can't do that either. + if (ContainsSendOrRecv(while_op->while_body()) || + ContainsSendOrRecv(while_op->while_condition())) { + VLOG(2) << "Not attempting to simplify while loop because it contains a " + "send/recv node: " + << while_op->ToShortString(); + continue; + } + + StatusOr result = TryRemoveWhileLoop(while_op); + TF_RETURN_IF_ERROR(result.status()); + if (result.ValueOrDie()) { + changed = true; + // Don't try to remove dead while params after successfully removing the + // while loop -- that would result in use-after-free nastiness. + continue; + } + + result = TryRemoveDeadWhileParams(while_op); + TF_RETURN_IF_ERROR(result.status()); + changed |= result.ValueOrDie(); + } + + XLA_VLOG_LINES(3, + "WhileLoopSimplifier::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..50dac32a4ab0a5de756c1ddf5e62c3560e54a079 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass that makes the following transformations on while loops: +// +// - A while loop with static trip count of 0 is deleted. +// - A while loops with static trip count of 1 is replaced by its body (sans +// loop). +// - Elements of a while loop's tuple that the loop doesn't use are removed +// from the tuple. +// +class WhileLoopSimplifier : public HloPassInterface { + public: + ~WhileLoopSimplifier() override {} + tensorflow::StringPiece name() const override { + return "simplify-while-loops"; + } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d99b31dc0037968bc88d5f22d53309a6a4546963 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -0,0 +1,422 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class WhileLoopSimplifierTest : public HloVerifiedTestBase { + public: + // Makes a computation that contains a loop that runs num_iters times. + HloComputation* MakeSimpleLoop(int num_iters, HloModule* module); + + // Makes a computation which has one parameter, of the given shape, and always + // returns PRED[]{true}. This is useful as a dummy loop condition. + HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, + HloModule* module); +}; + +HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(int num_iters, + HloModule* module) { + HloComputation::Builder builder(TestName()); + + auto loop_iter_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + auto loop_data_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0, 1, 2}))); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({loop_iter_init, loop_data_init})); + + HloComputation* condition; + { + HloComputation::Builder cond_builder(TestName() + ".condition"); + auto loop_var = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + auto loop_induction_var = + cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(S32, {}), loop_var, 0)); + auto limit = cond_builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(42 + num_iters))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, loop_induction_var, + limit)); + condition = module->AddEmbeddedComputation(cond_builder.Build()); + } + + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto loop_var = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + auto loop_induction_var = + body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(S32, {}), loop_var, 0)); + auto new_loop_induction_var = + body_builder.AddInstruction(HloInstruction::CreateBinary( + loop_induction_var->shape(), HloOpcode::kAdd, loop_induction_var, + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))))); + auto loop_data = + body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + loop_data_init->shape(), loop_var, 1)); + auto new_loop_data = + body_builder.AddInstruction(HloInstruction::CreateBinary( + loop_data_init->shape(), HloOpcode::kMultiply, loop_data, + loop_data)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({new_loop_induction_var, new_loop_data})); + body = module->AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + return module->AddEntryComputation(builder.Build()); +} + +HloComputation* WhileLoopSimplifierTest::MakeAlwaysTrueComputation( + const Shape& param_shape, HloModule* module) { + HloComputation::Builder builder(TestName() + ".always_true"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "param")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + return module->AddEmbeddedComputation(builder.Build()); +} + +TEST_F(WhileLoopSimplifierTest, WhileLoopWithZeroIterations) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/0, &module()); + ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(op::Constant(), op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, WhileLoopWithOneIteration) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(op::Add(), op::Multiply())); +} + +TEST_F(WhileLoopSimplifierTest, WhileLoopWithTwoIterations) { + MakeSimpleLoop(/*num_iters=*/2, &module()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* true_op = while_op->while_body()->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + TF_ASSERT_OK(true_op->AddControlDependencyTo( + while_op->while_body()->root_instruction())); + ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction()->control_predecessors(), + ElementsAre(op::Constant())) + << computation->ToString(); +} + +// Loops that contain send/recv nodes can't be simplified; the loop structure +// around send/recv nodes must be preserved. +TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* while_body = while_op->while_body(); + auto* send = while_body->AddInstruction(HloInstruction::CreateSend( + while_body->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))), + /*channel_id=*/0)); + while_body->AddInstruction(HloInstruction::CreateSendDone(send)); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* while_body = while_op->while_body(); + auto* recv = while_body->AddInstruction( + HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), + /*channel_id=*/0)); + while_body->AddInstruction(HloInstruction::CreateRecvDone(recv)); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// The limitation on not being able to simplify loops that contain infeeds (and +// other non-removable instructions) isn't fundamental -- it just stems from the +// fact that our infrastructure sees simplifying such a loop as tantamount to +// removing the non-removable instruction. +TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* while_body = while_op->while_body(); + while_body->AddInstruction( + HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// Check that we don't crash when given a loop whose shape is not a tuple. +TEST_F(WhileLoopSimplifierTest, IgnoreNonTupleShapedLoop) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + HloComputation* condition; + { + HloComputation::Builder cond_builder(TestName() + ".condition"); + auto param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(100))))); + condition = module().AddEmbeddedComputation(cond_builder.Build()); + } + + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + body_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-1))))); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// Construct a loop where we swap the tuple elements in each iteration. +// Although the tuple elements aren't used in the loop, we don't eliminate them, +// because the swapping side-effect is visible to users of the loop. +TEST_F(WhileLoopSimplifierTest, SwapTupleIndices) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))), + })); + + HloComputation* condition = + MakeAlwaysTrueComputation(loop_init->shape(), &module()); + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + body_builder.AddInstruction(HloInstruction::CreateTuple({ + body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)), + body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)), + })); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// Construct a loop where we assign a constant to tuple element 0 in each +// iteration. We can't eliminate tuple element 0, even though we never use its +// value. +TEST_F(WhileLoopSimplifierTest, UnusedButModifiedTupleElement) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0)))})); + + HloComputation* condition = + MakeAlwaysTrueComputation(loop_init->shape(), &module()); + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + body_builder.AddInstruction(HloInstruction::CreateTuple({ + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))), + })); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// Nothing to simplify in a while loop whose tuple has 0 elements. +TEST_F(WhileLoopSimplifierTest, EmptyTuple) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({})); + + HloComputation* condition = + MakeAlwaysTrueComputation(loop_init->shape(), &module()); + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + body_builder.AddInstruction(HloInstruction::CreateTuple({})); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// While loop where one tuple element is used twice in the body, and thus can't +// be simplified away. +TEST_F(WhileLoopSimplifierTest, ElemUsedTwice) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))), + })); + + HloComputation* condition = + MakeAlwaysTrueComputation(loop_init->shape(), &module()); + + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto* param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "param0")); + auto* gte0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); + // get0 is used twice in the loop body's tuple. + body_builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte0})); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// This while loop has three tuple elements. Element 0 is unused and should be +// removed. Element 1 is used by the loop body, and element 2 is used by the +// loop condition; these two should stay. +TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + })); + auto loop_shape = loop_init->shape(); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + + HloComputation* condition; + { + HloComputation::Builder cond_builder(TestName() + ".loop_condition"); + auto param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_shape, "param0")); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + scalar_s32, param, /*index=*/2)))); + condition = module().AddEmbeddedComputation(cond_builder.Build()); + } + + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto* param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_shape, "loop_var")); + + auto* tuple0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); + auto* tuple1 = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, + body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + scalar_s32, param, /*index=*/1)), + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))))); + auto* tuple2 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/2)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({tuple0, tuple1, tuple2})); + + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + module().AddEntryComputation(builder.Build()); + EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + + // We leave most of the checking to HloVerifiedTestBase, which runs the + // verifier on module() at the end of this test. + HloInstruction* new_while_op = *std::find_if( + module().entry_computation()->instructions().begin(), + module().entry_computation()->instructions().end(), + [&](const HloInstruction* instr) { + return instr != while_op && instr->opcode() == HloOpcode::kWhile; + }); + EXPECT_TRUE( + ShapeUtil::Equal(new_while_op->shape(), + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}))) + << ShapeUtil::HumanString(new_while_op->shape()); + EXPECT_THAT( + new_while_op->while_body()->root_instruction(), + op::Tuple( + op::Add(op::GetTupleElement(op::Parameter(0), /*tuple_index=*/0), + op::Constant()), + op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); + + EXPECT_THAT(new_while_op->while_condition()->root_instruction(), + op::Eq(op::Constant(), + op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 5bf9842a6ce7be747f58c10f302f85c6f82ac6f9..789eba5780d37e1fd4d80ec881855951c8bba0eb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -32,13 +32,13 @@ tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { return tensorflow::Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const { - if (!ShapeUtil::Compatible(*other_shape, shape_)) { +tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { + if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*other_shape).c_str(), + ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } - *other_shape = shape_; + *to_shape = shape_; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 92564660f21bf1b596c4b9ca04c07eaca27ed192..4c83750f3e6f3c735db66d8e0b86ae3f43e5ca11 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -38,18 +38,19 @@ class ShapeLayout { explicit ShapeLayout(const Shape& shape) : shape_(shape) {} // Assigns the layouts in this ShapeLayout to the Layout fields of the given - // shape. 'shape' and the shape of the ShapeLayout object must be compatible. - tensorflow::Status AssignLayoutToShape(Shape* shape) const; + // shape. 'to_shape' and the shape of the ShapeLayout object must be + // compatible. + tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible // with the ShapeLayout's shape, then false is returned. bool MatchesLayoutInShape(const Shape& shape) const; - // Copies the layout from the given shape into this ShapeLayout. 'shape' must - // be compatible with the ShapeLayout's shape, and 'shape' must have a layout - // (LayoutUtil::HasLayout). - tensorflow::Status CopyLayoutFromShape(const Shape& shape); + // Copies the layout from the given shape into this ShapeLayout. 'other_shape' + // must be compatible with the ShapeLayout's shape, and 'other_shape' must + // have a layout (LayoutUtil::HasLayout). + tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index 64a36471b9f1b35517c29c01554e02c5d1035086..d752619bd65751779c24f061e44e206d66b01465 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -116,6 +116,7 @@ class ShapeTree { ShapeTree(const Shape* shape, const T& init_value); ShapeTree(const ShapeTree& other) { *this = other; } + ShapeTree(ShapeTree&&) = default; ShapeTree& operator=(const ShapeTree& other) { root_ = other.root_; @@ -132,6 +133,8 @@ class ShapeTree { return *this; } + ShapeTree& operator=(ShapeTree&& other) = default; + // Returns the data element associated with the array in the shape at the // given index (see ShapeUtil::GetSubshape for how indexes are defined). const T& element(const ShapeIndex& index) const; @@ -152,28 +155,57 @@ class ShapeTree { using const_iterator = ShapeTreeIterator; // begin/end for iterating over all nodes. - iterator begin() { return iterator(&root_, /*iterate_leaves_only=*/false); } - iterator end() { return iterator(nullptr, /*iterate_leaves_only=*/false); } + iterator begin() { + return iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } + iterator end() { + return iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } const_iterator begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/false); + return const_iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/false); } const_iterator end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/false); + return const_iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/false); + } + + // rbegin/rend for iterating over all nodes in reverse. + iterator rbegin() { + return iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + iterator rend() { + return iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + const_iterator rbegin() const { + return const_iterator(&root_, /*iterate_leaves_only=*/false, + /*reverse=*/true); + } + const_iterator rend() const { + return const_iterator(nullptr, /*iterate_leaves_only=*/false, + /*reverse=*/true); } // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no // children). iterator leaf_begin() { - return iterator(&root_, /*iterate_leaves_only=*/true); + return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/false); } iterator leaf_end() { - return iterator(nullptr, /*iterate_leaves_only=*/true); + return iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/false); } const_iterator leaf_begin() const { - return const_iterator(&root_, /*iterate_leaves_only=*/true); + return const_iterator(&root_, /*iterate_leaves_only=*/true, + /*reverse=*/false); } const_iterator leaf_end() const { - return const_iterator(nullptr, /*iterate_leaves_only=*/true); + return const_iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/false); } // range-based iterator for leaf_begin()/leaf_end(). tensorflow::gtl::iterator_range leaves() { @@ -183,6 +215,22 @@ class ShapeTree { return tensorflow::gtl::make_range(leaf_begin(), leaf_end()); } + iterator leaf_rbegin() { + return iterator(&root_, /*iterate_leaves_only=*/true, /*reverse=*/true); + } + iterator leaf_rend() { + return iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + const_iterator leaf_rbegin() const { + return const_iterator(&root_, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + const_iterator leaf_rend() const { + return const_iterator(nullptr, /*iterate_leaves_only=*/true, + /*reverse=*/true); + } + // Recursively traverses the shape and calls the given function at each // element. The function has the following arguments: // @@ -190,7 +238,7 @@ class ShapeTree { // (or compatible). // index : the index of the element in the shape. See ShapeUtil::GetSubshape // for definition of index. - // data : The data value at this elemnt. + // data : The data value at this element. template void ForEachElement(const Fn& func) const; @@ -277,42 +325,61 @@ class ShapeTreeIterator : public std::iteratorchildren.empty() && iterate_leaves_only) { - ++*this; + // interior tree nodes, only leaves. If reverse is true, the iterator will + // visit nodes in the reverse of pre-order traversal. + ShapeTreeIterator(NodeType* node, bool iterate_leaves_only, bool reverse) + : node_(node), + iterate_leaves_only_(iterate_leaves_only), + reverse_(reverse) { + if (node_) { + if (reverse_) { + while (!node_->children.empty()) { + const int child_index = node_->children.size() - 1; + stack_.push_back({node_, child_index}); + node_ = node_->children[child_index].get(); + } + } else { + if (!node_->children.empty() && iterate_leaves_only) { + ++*this; + } + } } } ShapeTreeIterator(const ShapeTreeIterator& other) : node_(other.node_), stack_(other.stack_), - iterate_leaves_only_(other.iterate_leaves_only_) {} + iterate_leaves_only_(other.iterate_leaves_only_), + reverse_(other.reverse_) {} ShapeTreeIterator& operator++() { CHECK_NE(nullptr, node_) << "walking off the end() of an iterator!"; - // We're doing a pre-order walk, so if our current node has children take - // the first child. - if (!node_->children.empty()) { - stack_.push_back({node_, /*child-index=*/0}); - node_ = node_->children[0].get(); - if (node_->children.empty() || !iterate_leaves_only_) { - return *this; - } else { - // This is a non-leaf; tail-recurse. - return ++(*this); + if (reverse_) { + while (!stack_.empty()) { + node_ = stack_.back().first; + int64 next_child_index = stack_.back().second - 1; + stack_.pop_back(); + if (next_child_index < 0) { + if (!iterate_leaves_only_) { + // All children are visited, yield . + return *this; + } + } else { + stack_.push_back({node_, next_child_index}); + node_ = node_->children[next_child_index].get(); + while (!node_->children.empty()) { + const int child_index = node_->children.size() - 1; + stack_.push_back({node_, child_index}); + node_ = node_->children[child_index].get(); + } + return *this; + } } - } - // Otherwise we are currently at a leaf. Walk back up until a node contains - // a child we haven't visited yet. - while (!stack_.empty()) { - node_ = stack_.back().first; - int64 next_child_index = stack_.back().second + 1; - stack_.pop_back(); - if (node_->children.size() > next_child_index) { - stack_.push_back({node_, next_child_index}); - node_ = node_->children[next_child_index].get(); - + } else { + // We're doing a pre-order walk, so if our current node has children take + // the first child. + if (!node_->children.empty()) { + stack_.push_back({node_, /*child-index=*/0}); + node_ = node_->children[0].get(); if (node_->children.empty() || !iterate_leaves_only_) { return *this; } else { @@ -320,6 +387,24 @@ class ShapeTreeIterator : public std::iteratorchildren.size() > next_child_index) { + stack_.push_back({node_, next_child_index}); + node_ = node_->children[next_child_index].get(); + + if (node_->children.empty() || !iterate_leaves_only_) { + return *this; + } else { + // This is a non-leaf; tail-recurse. + return ++(*this); + } + } + } } // We've walked off the end of the tree. Set node_ to nullptr to signify // end(). @@ -361,6 +446,8 @@ class ShapeTreeIterator : public std::iterator> stack_; // True if we should not include interior nodes in our walk. bool iterate_leaves_only_; + // True if we should yield the reverse of the pre-order traversal. + bool reverse_; // Placeholder for the current value. Ideally this wouldn't exist and would // just be an rvalue, but operator -> needs to return a pointer to something. // We cannot just use a plain old value_type as it contains a reference so diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc index 7b4b5cb0fb5e1564ca12ac6e3b901e94ea4c8db6..4b6ab772811f4a6c6ffc1d10befc7122f883b8f9 100644 --- a/tensorflow/compiler/xla/shape_tree_test.cc +++ b/tensorflow/compiler/xla/shape_tree_test.cc @@ -456,6 +456,26 @@ TEST_F(ShapeTreeTest, IterateOrder) { {2, 1}})); } +TEST_F(ShapeTreeTest, ReverseIterateOrder) { + ShapeTree t(nested_tuple_shape_, 42); + std::vector v; + for (auto it = t.rbegin(); it != t.rend(); ++it) { + v.push_back(it->first); + } + EXPECT_EQ(v, (std::vector{ + {2, 1}, + {2, 0, 1}, + {2, 0, 0}, + {2, 0}, + {2}, + {1, 1}, + {1, 0}, + {1}, + {0}, + {}, + })); +} + TEST_F(ShapeTreeTest, IterateOrderLeaves) { ShapeTree t(nested_tuple_shape_, 42); std::vector v; @@ -466,5 +486,21 @@ TEST_F(ShapeTreeTest, IterateOrderLeaves) { {0}, {1, 0}, {1, 1}, {2, 0, 0}, {2, 0, 1}, {2, 1}})); } +TEST_F(ShapeTreeTest, ReverseIterateOrderLeaves) { + ShapeTree t(nested_tuple_shape_, 42); + std::vector v; + for (auto it = t.leaf_rbegin(); it != t.leaf_rend(); ++it) { + v.push_back(it->first); + } + EXPECT_EQ(v, (std::vector{ + {2, 1}, + {2, 0, 1}, + {2, 0, 0}, + {1, 1}, + {1, 0}, + {0}, + })); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index b5eb81dfc6a4117909dcb18fdbe61443b1a1eb95..fe5166643df573ab8cbbea56ac791bccf5b7a4a8 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -263,6 +264,7 @@ StatusOr MakeShapeWithLayoutInternal( case S32: case S64: case F16: + case BF16: case F32: case F64: return true; @@ -328,6 +330,14 @@ StatusOr MakeShapeWithLayoutInternal( return MakeTupleShape(new_elements); } +// Returns the shape of a real or imaginary component. +/* static */ Shape ShapeUtil::ComplexComponentShape( + const Shape& complex_shape) { + CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape); + return ChangeElementType(complex_shape, primitive_util::ComplexComponentType( + complex_shape.element_type())); +} + /* static */ bool ShapeUtil::ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions) { @@ -395,6 +405,26 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); return gen->LowercaseName(s); } + +StatusOr StringToPrimitiveType(const string& name) { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + auto found = name_to_type->find(name); + if (found == name_to_type->end()) { + return InvalidArgument("Invalid element type string: \"%s\".", + name.c_str()); + } + return found->second; +} + } // namespace /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { @@ -499,17 +529,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { comma_list_to_int64s(dimensions_string)); // Extract the primitive element type. - PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID; - for (PrimitiveType i = - static_cast(PRIMITIVE_TYPE_INVALID + 1); - i < TUPLE; i = static_cast(i + 1)) { - if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) == - element_type_string) { - primitive_type = i; - break; - } - } - if (primitive_type == PRIMITIVE_TYPE_INVALID) { + TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, + StringToPrimitiveType(element_type_string)); + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || + primitive_type == OPAQUE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } @@ -552,6 +575,16 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs); } +/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs, + const Shape& rhs) { + if (lhs.element_type() == TUPLE) { + return rhs.element_type() == TUPLE && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringElementType); + } + return SameDimensions(lhs, rhs); +} + /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, int64 dimension_number) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); @@ -591,6 +624,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(uint32); case U64: return sizeof(uint64); + case BF16: + return sizeof(float) / 2; case F16: return sizeof(float) / 2; case F32: @@ -681,9 +716,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return LayoutUtil::ValidateLayoutInShape(shape); } -/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape, +/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = shape; + Shape new_shape = original; new_shape.set_element_type(type); return new_shape; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 8f8d4a73c9ecb3f4236f3877323ad1127bb0b9c2..666c7da697c7cbad4dc30a7b3feb2b2804562442 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -68,6 +68,9 @@ class ShapeIndex { const int64* data() const { return indices_.data(); } + int64 back() const { return indices_.back(); } + int64& back() { return indices_.back(); } + const int64& operator[](size_t i) const { return indices_[i]; } int64& operator[](size_t i) { return indices_[i]; } @@ -167,7 +170,7 @@ class ShapeUtil { // As above, but for program shapes, returns a string for the form: // // (param_name: f32[42x12], ...) -> f32[24x42] - static string HumanString(const ProgramShape& shape); + static string HumanString(const ProgramShape& program_shape); // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. @@ -187,6 +190,11 @@ class ShapeUtil { // compatibility. static bool Compatible(const Shape& lhs, const Shape& rhs); + // Returns true if the rank and dimension sizes are identical. Element type + // and layout are ignored. Tuple elements are compared recursively for + // compatibility. + static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); + // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); @@ -343,6 +351,10 @@ class ShapeUtil { // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); + // Returns the shape of the real/imaginary components of the given complex + // shape. + static Shape ComplexComponentShape(const Shape& complex_shape); + // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. // diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 0ba542ad1bec290c35c52a8dd5177893770310fd..4bce7ca51d0534cbcad6faac12818c5f3e94b29e 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -145,6 +145,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { Shape tuple2 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})}); EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { @@ -153,6 +154,7 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { Shape tuple2 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})}); EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); + EXPECT_TRUE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index 5fa2211ac66177514ac8ecabfa8791e7c8c014a2..f9d25945bc617507735fb6c4d011c39723497f69 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -32,26 +32,26 @@ namespace { class Base1 { public: virtual ~Base1() {} - int pad; + int pad_; }; class Base2 { public: virtual ~Base2() {} - int yetotherpad; + int yetotherpad_; }; class Derived : public Base1, public Base2 { public: ~Derived() override {} - int evenmorepad; + int evenmorepad_; }; class CopyNoAssign { public: - explicit CopyNoAssign(int value) : foo(value) {} - CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {} - int foo; + explicit CopyNoAssign(int value) : foo_(value) {} + CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} + int foo_; private: const CopyNoAssign& operator=(const CopyNoAssign&); @@ -253,7 +253,7 @@ TEST(StatusOr, TestCopyCtorNonAssignable) { StatusOr original(value); StatusOr copy(original); EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo); + EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_); } TEST(StatusOr, TestCopyCtorStatusOKConverting) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 4e1be24b61cc436b0baf62cc6e28ad8d13fe71ac..6af01ae80d9ac8cdf8e7ba5cff4c24ef1d31cf94 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -61,14 +61,19 @@ generate_backend_test_macros() cc_library( name = "test_utils", - testonly = True, + srcs = ["test_utils.cc"], hdrs = ["test_utils.h"], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_headers_lib", ], ) @@ -100,7 +105,9 @@ cc_library( hdrs = ["hlo_test_base.h"], deps = [ ":literal_test_util", + ":test_utils", "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -110,6 +117,9 @@ cc_library( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -427,6 +437,27 @@ xla_test( ], ) +xla_test( + name = "conditional_test", + srcs = ["conditional_test.cc"], + # Currently, Conditional is supported only in CPU and GPU backends. + backends = [ + "cpu", + "gpu", + ], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], @@ -508,6 +539,7 @@ xla_test( name = "array_elementwise_ops_test", srcs = ["array_elementwise_ops_test.cc"], shard_count = 25, + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -766,6 +798,41 @@ xla_test( ], ) +xla_test( + name = "bfloat16_test", + srcs = ["bfloat16_test.cc"], + blacklisted_backends = [ + "gpu", + ], + shard_count = 40, + deps = [ + ":test_utils", + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "slice_test", srcs = ["slice_test.cc"], @@ -1226,6 +1293,23 @@ xla_test( ], ) +xla_test( + name = "bitcast_convert_test", + srcs = ["bitcast_convert_test.cc"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core:test", + ], +) + xla_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.cc"], @@ -1290,6 +1374,7 @@ xla_test( srcs = ["client_test.cc"], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", @@ -1343,22 +1428,23 @@ xla_test( ], ) -xla_test( +tf_cc_test( name = "llvm_compiler_test", srcs = ["llvm_compiler_test.cc"], - backends = [ - "cpu", - "gpu", - "cpu_parallel", - ], + tags = ["requires-gpu-sm35"], deps = [ - "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:llvm_compiler", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/gpu:gpu_compiler", "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor", "@llvm//:core", ], ) @@ -1596,6 +1682,65 @@ tf_cc_test( ], ) +xla_test( + name = "transfer_manager_test", + srcs = ["transfer_manager_test.cc"], + deps = [ + ":literal_test_util", + ":local_client_test_base", + ":xla_internal_test_main", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:generic_transfer_manager", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +# A demo of textual IR based test. +xla_test( + name = "sample_text_test", + srcs = ["sample_text_test.cc"], + # You can leave this empty if you want to test all supported backends. + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +# A demo of test that loads an hlo module from a file and compares results on gpu and cpu. +tf_cc_test( + name = "sample_file_test", + srcs = ["sample_file_test.cc"], + data = ["isolated_convolution.hlo"], + tags = ["requires-gpu-sm35"], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:cpu_plugin", # reference backend + "//tensorflow/compiler/xla/service:gpu_plugin", # test backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index a62b13e04ff35b06846039d7665dfc8e4205eec2..c6e8b24d1211743d07878d388522feacf9c0e7f1 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -82,6 +82,25 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS32) { {}); } +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementC64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto result = builder.Neg(a); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 1.0f}, {0.0f, 3.14f}, {2.25f, -1.0f}, {-10.0f, 0.0f}}); + auto result = builder.Neg(a); + + ComputeAndCompareR1( + &builder, {{2.5f, -1.0f}, {0.0f, -3.14f}, {-2.25f, 1.0f}, {10.0f, 0.0f}}, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); @@ -145,6 +164,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {2.25f, 0.0f}, {1.0f, -10.0f}}); + auto b = builder.ConstantR1( + {{100.0f, 0.0f}, {3.13f, 0.0f}, {2.75f, 1.0f}, {-2.0f, 10.5f}}); + auto add = builder.Add(a, b); + + ComputeAndCompareR1( + &builder, {97.5f, {3.13f, 3.14f}, {5.0f, 1.0f}, {-1.0f, 0.5f}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Add(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); @@ -222,6 +263,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) { ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 3.14f}, {3.0f, 2.25f}}); + auto b = builder.ConstantR1( + {{0.0f, 10.0f}, {3.13f, 0.0f}, {2.75f, -0.25f}}); + auto add = builder.Sub(a, b); + + ComputeAndCompareR1( + &builder, {{-2.5f, -10.0f}, {-3.13f, 3.14f}, {0.25f, 2.5f}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Sub(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f}); @@ -385,6 +448,27 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivU32s) { } } +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 1.0f}, {-25.5f, 0.0f}, {2.0f, -1.0f}}); + auto b = builder.ConstantR1( + {{10.0f, 0.0f}, {0.0f, 1.0f}, {2.0f, -1.0f}}); + auto div = builder.Div(a, b); + + ComputeAndCompareR1( + &builder, {{-0.25f, 0.1f}, {0.0f, 25.5f}, {1.0f, 0.0f}}, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto div = builder.Div(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1( @@ -496,6 +580,28 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { ComputeAndCompareR1(&builder, expected, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {{-2.5f, 0.0f}, {0.0f, 25.5f}, {2.0f, -10.0f}}); + auto b = builder.ConstantR1( + {{0.0f, 10.0f}, {5.0f, 1.0f}, {10.0f, -6.0f}}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1( + &builder, {{0.0f, -25.0f}, {-25.5f, 127.5f}, {-40.0f, -112.0}}, {}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto add = builder.Mul(a, b); + + ComputeAndCompareR1(&builder, {}, {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); @@ -886,6 +992,53 @@ XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) { ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqC64s) { + SetFastMathDisabled(true); + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = builder.ConstantR1({{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {false, false, true, false, false}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementC64s) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({}); + auto rhs = builder.ConstantR1({}); + auto compare = builder.Eq(lhs, rhs); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, CompareNeC64s) { + // Disable fast-math because we're operating on NaNs. + SetFastMathDisabled(true); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({{-2.5f, 10.0f}, + {1.0f, 25.5f}, + {2.25f, -3.0f}, + {NAN, 0.0f}, + {1.0f, 6.0f}}); + auto rhs = builder.ConstantR1({{0.0f, 10.0f}, + {1.0f, 5.0f}, + {2.25f, -3.0f}, + {10.0f, 0.0f}, + {1.0f, NAN}}); + auto compare = builder.Ne(lhs, rhs); + + ComputeAndCompareR1(&builder, {true, true, false, true, true}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareNeF32s) { // Disable fast-math because we're operating on NaNs. SetFastMathDisabled(true); @@ -2027,7 +2180,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { const string expected = R"(pred[2,2] { { 00 }, - { 01 }, + { 01 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2041,7 +2194,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { const string expected = R"(pred[2,4] { { 1100 }, - { 0001 }, + { 0001 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2055,7 +2208,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { const string expected = R"(pred[2,4] { { 0100 }, - { 0000 }, + { 0000 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2069,7 +2222,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { const string expected = R"(pred[2,4] { { 1011 }, - { 1111 }, + { 1111 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2083,7 +2236,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { const string expected = R"(pred[2,4] { { 0011 }, - { 1110 }, + { 1110 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac3f3f4c9ddb03d003a44f5abd7a2e26c42f490d --- /dev/null +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/reference_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class Bfloat16Test : public ClientLibraryTestBase { + protected: + const ErrorSpec error_spec_{0.001, 0.001}; +}; + +XLA_TEST_F(Bfloat16Test, ScalarOperation) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(static_cast(2.0f)); + auto y = builder.ConstantR0(static_cast(1.0f)); + builder.Add(x, y); + + ComputeAndCompareR0(&builder, static_cast(3.0f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, LogOperation) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(static_cast(4.0f)); + builder.Log(x); + + ComputeAndCompareR0(&builder, static_cast(1.387f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, NegateScalarF16) { + ComputationBuilder builder(client_, TestName()); + builder.Neg(builder.ConstantR0(static_cast(2.1f))); + + ComputeAndCompareR0(&builder, static_cast(-2.1f), {}, + error_spec_); +} + +XLA_TEST_F(Bfloat16Test, BatchNormTraining) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{static_cast(1.f)}, {static_cast(2.f)}}, + {{static_cast(3.f)}, {static_cast(4.f)}}}, + {{{static_cast(5.f)}, {static_cast(6.f)}}, + {{static_cast(7.f)}, {static_cast(8.f)}}}}); + + auto scale = builder.ConstantR1( + {static_cast(2.0f), static_cast(3.0f)}); + + auto offset = builder.ConstantR1( + {static_cast(1.0f), static_cast(2.0f)}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4( + {{{{static_cast(-1.7f)}, {static_cast(-2.04f)}}, + {{static_cast(0.105f)}, {static_cast(0.65f)}}}, + {{{static_cast(1.89f)}, {static_cast(3.35f)}}, + {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) + .get(), + Literal::CreateR1( + {static_cast(4), static_cast(5)}) + .get(), + Literal::CreateR1( + {static_cast(5), static_cast(5)}) + .get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); +} + +XLA_TEST_F(Bfloat16Test, BatchNormGrad) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + Array4D(2, 2, 2, 1, static_cast(0.0f))); + + auto scale = builder.ConstantR1( + {static_cast(1.0f), static_cast(1.0f)}); + + auto mean = builder.ConstantR1( + {static_cast(0.0f), static_cast(0.0f)}); + + auto var = builder.ConstantR1( + {static_cast(1.0f), static_cast(1.0f)}); + + auto grad_output = builder.ConstantR4FromArray4D( + {{{{static_cast(1.f)}, {static_cast(2.f)}}, + {{static_cast(3.f)}, {static_cast(4.f)}}}, + {{{static_cast(5.f)}, {static_cast(6.f)}}, + {{static_cast(7.f)}, {static_cast(8.f)}}}}); + + builder.BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); + + auto expected = *Literal::MakeTuple( + {Literal::CreateR4( + {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, + {{static_cast(-1.f)}, {static_cast(-1.f)}}}, + {{{static_cast(1.f)}, {static_cast(1.f)}}, + {{static_cast(3.f)}, {static_cast(3.f)}}}}) + .get(), + Literal::CreateR1( + {static_cast(0), static_cast(0)}) + .get(), + Literal::CreateR1( + {static_cast(16), static_cast(20)}) + .get()}); + + ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/bitcast_convert_test.cc b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0d94d65c1015fb54ada3fdfc95d0c31d0a0f158b --- /dev/null +++ b/tensorflow/compiler/xla/tests/bitcast_convert_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class BitcastConvertTest : public ClientLibraryTestBase { + public: + explicit BitcastConvertTest(perftools::gputools::Platform* platform = nullptr) + : ClientLibraryTestBase(platform) { + mutable_debug_options()->add_xla_disable_hlo_passes("algsimp"); + mutable_debug_options()->add_xla_disable_hlo_passes("inline"); + } +}; + +TEST_F(BitcastConvertTest, ConvertR1S32ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42, 64}); + builder.BitcastConvertType(a, S32); + + std::vector expected = {42, 64}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertR1F32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.0f, 64.0f}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, BitcastR1S32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({0, static_cast(0x80000000), 0x3F800000, + static_cast(0xBF800000), 0x3F000000, + static_cast(0xBF000000)}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {0.0f, -0.0f, 1.0f, -1.0f, 0.5f, -0.5f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +XLA_TEST_F(BitcastConvertTest, ConvertR1S0S32ToR1S0F32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertR1F32ToR1S32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({42.6, 64.4}); + builder.BitcastConvertType(a, S32); + + std::vector expected = {0x422a6666, 0x4280cccd}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertS32Extremes) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1( + {std::numeric_limits::min(), std::numeric_limits::max()}); + builder.BitcastConvertType(a, F32); + + std::vector expected = {-0.0f, NAN}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0, 0)); +} + +TEST_F(BitcastConvertTest, ConvertMapToS32) { + ComputationBuilder builder(client_, TestName()); + auto b = builder.CreateSubBuilder("convert"); + auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); + b->BitcastConvertType(param, S32); + auto a = builder.ConstantR1({42.0f, 64.0f}); + builder.Map({a}, b->BuildAndNoteError(), {0}); + + std::vector expected = {0x42280000, 0x42800000}; + ComputeAndCompareR1(&builder, expected, {}); +} + +TEST_F(BitcastConvertTest, ConvertMapToF32) { + ComputationBuilder builder(client_, TestName()); + auto b = builder.CreateSubBuilder("convert"); + auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); + b->BitcastConvertType(param, F32); + auto a = builder.ConstantR1({0x42280000, 0x42800000}); + builder.Map({a}, b->BuildAndNoteError(), {0}); + + std::vector expected = {42.0f, 64.0f}; + ComputeAndCompareR1(&builder, expected, {}); +} + +// Regression test for b/31758660. When ReshapeMover transforms +// input -> reshape -> convert +// to +// input -> convert -> reshape +// the new convert should have the same element type as the old convert. +TEST_F(BitcastConvertTest, ConvertReshape) { + ComputationBuilder builder(client_, TestName()); + auto input = builder.ConstantR1({0x42280000}); + auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{}); + builder.BitcastConvertType(reshape, F32); + + ComputeAndCompareR0(&builder, 42.0f, {}); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 36d10fff5400b78fa3ea9a03f6b9cd73059f1427..610302ac1256a57db6ed6e18016a4136973e3891 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -29,6 +29,7 @@ def xla_test(name, deps, xla_test_library_deps=[], backends=[], + blacklisted_backends=[], args=[], tags=[], copts=[], @@ -92,17 +93,24 @@ def xla_test(name, backends: A list of backends to generate tests for. Supported values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will be generated for all supported backends. + blacklisted_backends: A list of backends to NOT generate tests for. args: Test arguments for the target. tags: Tags for the target. - backend_args: A dict mapping backend name to list of additional args to - use for that target. + copts: Additional copts to pass to the build. + data: Additional data to pass to the build. backend_tags: A dict mapping backend name to list of additional tags to use for that target. + backend_args: A dict mapping backend name to list of additional args to + use for that target. + **kwargs: Additional keyword arguments to pass to native.cc_test. """ test_names = [] if not backends: backends = all_backends + backends = [backend for backend in backends + if backend not in blacklisted_backends] + native.cc_library( name="%s_lib" % name, srcs=srcs, @@ -248,5 +256,6 @@ def generate_backend_test_macros(backends=[]): deps = [ "//tensorflow/compiler/xla:types", "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", "//tensorflow/core:test", ]) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 065bce7e3146c93568bbce2b0e7e23ddddc4ea31..50bf185936808fbd9c49f7fbd5ab0c0b4a76504b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -262,20 +262,39 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( expected.shape().element_type() == PRED) << ShapeUtil::HumanString(expected.shape()); } + // We allow using a float expected literal for a bfloat16 output. In this + // case, we need to convert the expected literal to bfloat16. + const Literal* expected_ptr = &expected; + std::unique_ptr converted_expected; + Shape layout_shape; + if (use_bfloat16_) { + converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + expected_ptr = converted_expected.get(); + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + shape_with_layout = &layout_shape; + } + } auto expect_equal = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectEqual(expected, actual, error_message); + LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message); }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { return ComputeAndCompareLiteralWithAllOutputLayouts( - computation, expected, arguments, expect_equal); + computation, *expected_ptr, arguments, expect_equal); } if (execution_options_.debug_options().xla_test_all_input_layouts()) { return ComputeAndCompareLiteralWithAllInputLayouts( - computation, expected, arguments, expect_equal, shape_with_layout); + computation, *expected_ptr, arguments, expect_equal, shape_with_layout); } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectEqual(expected, *actual); + LiteralTestUtil::ExpectEqual(*expected_ptr, *actual); return tensorflow::Status::OK(); } @@ -286,20 +305,39 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + // We allow using a float expected literal for a bfloat16 output. In this + // case, we need to convert the expected literal to bfloat16. + const Literal* expected_ptr = &expected; + std::unique_ptr converted_expected; + Shape layout_shape; + if (use_bfloat16_) { + converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); + expected_ptr = converted_expected.get(); + if (shape_with_layout != nullptr) { + layout_shape = *shape_with_layout; + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); + shape_with_layout = &layout_shape; + } + } auto expect_near = [&](const Literal& actual, const string& error_message) { - LiteralTestUtil::ExpectNear(expected, actual, error, error_message); + LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message); }; if (execution_options_.debug_options().xla_test_all_output_layouts()) { - return ComputeAndCompareLiteralWithAllOutputLayouts(computation, expected, - arguments, expect_near); + return ComputeAndCompareLiteralWithAllOutputLayouts( + computation, *expected_ptr, arguments, expect_near); } if (execution_options_.debug_options().xla_test_all_input_layouts()) { return ComputeAndCompareLiteralWithAllInputLayouts( - computation, expected, arguments, expect_near, shape_with_layout); + computation, *expected_ptr, arguments, expect_near, shape_with_layout); } TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments, shape_with_layout)); - LiteralTestUtil::ExpectNear(expected, *actual, error); + LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error); return tensorflow::Status::OK(); } @@ -346,10 +384,67 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( LiteralTestUtil::ExpectNearTuple(expected, *actual, error); } +void ClientLibraryTestBase::ComputeAndCompare( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments) { + auto status_or_data = ComputeValueAndReference(builder, operand, arguments); + EXPECT_IS_OK(status_or_data); + if (!status_or_data.ok()) { + return; + } + std::unique_ptr reference, result; + std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); + LiteralTestUtil::ExpectEqual(*reference, *result); +} + +void ClientLibraryTestBase::ComputeAndCompare( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { + auto status_or_data = ComputeValueAndReference(builder, operand, arguments); + EXPECT_IS_OK(status_or_data); + if (!status_or_data.ok()) { + return; + } + std::unique_ptr reference, result; + std::tie(reference, result) = status_or_data.ConsumeValueOrDie(); + LiteralTestUtil::ExpectNear(*reference, *result, error); +} + +StatusOr, std::unique_ptr>> +ClientLibraryTestBase::ComputeValueAndReference( + ComputationBuilder* builder, const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments) { + // Transfer the arguments to the executor service. We put the unique_ptr's + // into a vector to keep the data alive on the service until the end of this + // function. + std::vector> argument_data; + for (const auto& arg : arguments) { + TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg)); + argument_data.push_back(std::move(data)); + } + + // Create raw pointers to the GlobalData for the rest of the call stack. + std::vector argument_data_ptr; + std::transform( + argument_data.begin(), argument_data.end(), + std::back_inserter(argument_data_ptr), + [](const std::unique_ptr& data) { return data.get(); }); + + TF_ASSIGN_OR_RETURN( + auto reference, + builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments)); + TF_ASSIGN_OR_RETURN(auto result, + ExecuteAndTransfer(builder, argument_data_ptr)); + return std::make_pair(std::move(reference), std::move(result)); +} + Computation ClientLibraryTestBase::CreateScalarRelu() { ComputationBuilder builder(client_, "relu"); - auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value"); - auto zero = builder.ConstantR0(0.0); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto z_value = builder.Parameter(0, shape, "z_value"); + auto zero = use_bfloat16_ + ? builder.ConstantR0(static_cast(0.0f)) + : builder.ConstantR0(0.0f); builder.Max(z_value, zero); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -358,8 +453,9 @@ Computation ClientLibraryTestBase::CreateScalarRelu() { Computation ClientLibraryTestBase::CreateScalarMax() { ComputationBuilder builder(client_, "max"); - auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y"); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto x = builder.Parameter(0, shape, "x"); + auto y = builder.Parameter(1, shape, "y"); builder.Max(x, y); auto computation_status = builder.Build(); TF_CHECK_OK(computation_status.status()); @@ -368,11 +464,12 @@ Computation ClientLibraryTestBase::CreateScalarMax() { Computation ClientLibraryTestBase::CreateScalarReluSensitivity() { ComputationBuilder builder(client_, "relu_sensitivity"); - auto activation = - builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "activation"); - auto backprop = - builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "backprop"); - auto zero = builder.ConstantR0(0.0); + auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {}); + auto activation = builder.Parameter(0, shape, "activation"); + auto backprop = builder.Parameter(1, shape, "backprop"); + auto zero = use_bfloat16_ + ? builder.ConstantR0(static_cast(0.0f)) + : builder.ConstantR0(0.0f); auto activation_gtz = builder.Gt(activation, zero); builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero); @@ -407,4 +504,27 @@ ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols, return array; } +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle) { + const Literal* param_literal = &literal; + std::unique_ptr converted_literal; + if (use_bfloat16_) { + converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); + param_literal = converted_literal.get(); + } + std::unique_ptr data = + client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + *data_handle = + builder->Parameter(parameter_number, param_literal->shape(), name); + return data; +} + +ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( + const Literal& literal, ComputationBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 7cfc276ec19e3b177f87a08e716cb34b7676dd6b..4d0cf8bf71cf22d7c046bb22754a8d4e299ed9db 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -194,7 +194,17 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); void ComputeAndCompareTuple( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec abs_error); + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); + + // Convenience method for running a built computation and comparing the result + // with the HloEvaluator. + void ComputeAndCompare(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments); + void ComputeAndCompare(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments, + ErrorSpec error); // Create scalar operations for use in reductions. Computation CreateScalarRelu(); @@ -235,51 +245,102 @@ class ClientLibraryTestBase : public ::testing::Test { const int rows, const int cols, const int rows_padded, const int cols_padded); - // Create a parameter instruction that wraps a given value and then stores + // Creates a parameter instruction, transfers the literal for the parameter to + // server, then stores into "data_handle" the global handle for that + // parameter. When the use_bfloat16 flag is set but the literal has F32 + // elements, the literal will be converted to BF16 before being transferred. + std::unique_ptr CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + ComputationBuilder* builder, ComputationDataHandle* data_handle); + + // Creates a constant instruction with the given literal. When the + // use_bfloat16 flag is set but the literal has F32 elements, the elements + // will be converted to BF16s. + ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, + ComputationBuilder* builder); + + // Creates a constant instruction with the given array. When the use_bfloat16 + // flag is set but the array has float elements, the elements will be + // converted to bfloat16s. + template + ComputationDataHandle CreateConstantFromArray(const Array& array, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + } + + // Same as CreateConstantFromArray, but for scalars. + template + ComputationDataHandle CreateConstantFromScalar(NativeT value, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateR0(value), + builder); + } + + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given values and then stores + // Creates a parameter instruction that wraps the given values and then stores // into "data_handle" the global handle for that parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given constant array + // Creates a parameter instruction that wraps the given constant array // "array_2d" and then stores to "data_handle" the global handle for that // parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); - // Create a parameter instruction that wraps the given constant array + // Creates a parameter instruction that wraps the given constant array // "array_3d" and then stores to "data_handle" the global handle for that // parameter. // // "parameter_number" is the parameter number. // "name" is the name of the parameter instruction. + // + // When the use_bfloat16 flag is set but NativeT is float, the data will be + // converted to bfloat16. template std::unique_ptr CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); + // Getter and setter for the use_bfloat16 flag, which indicates whether to run + // tests with all float-type input/output converted to bfloat16. + bool use_bfloat16() const { return use_bfloat16_; } + void set_use_bfloat16(bool value) { use_bfloat16_ = value; } + + // The float type used in this test, BF16 or F32 according to use_bfloat16. + PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + Client* client_; ExecutionOptions execution_options_; @@ -298,6 +359,17 @@ class ClientLibraryTestBase : public ::testing::Test { const std::function& verify_output, const Shape* output_with_layout = nullptr); + + // Executes the computation and calculates the expected reference value using + // the HloEvaluator. Returns two literal in the order of (expected, actual). + StatusOr, std::unique_ptr>> + ComputeValueAndReference(ComputationBuilder* builder, + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice arguments); + + // Whether to run tests with all float-type input/output converted to + // bfloat16. + bool use_bfloat16_ = false; }; template @@ -315,8 +387,10 @@ void ClientLibraryTestBase::ComputeAndCompareR0( ComputationBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -338,8 +412,10 @@ void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -362,6 +438,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -386,6 +463,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -410,6 +488,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -423,6 +502,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR0(value); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -435,6 +517,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR1(values); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -447,6 +532,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -459,6 +547,9 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); + if (use_bfloat16_ && literal->shape().element_type() == F32) { + literal = LiteralTestUtil::ConvertF32ToBF16(*literal); + } std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -469,8 +560,7 @@ template std::vector ClientLibraryTestBase::CreatePseudorandomR1( const int width, NativeT min_value, NativeT max_value, uint32 seed) { std::vector result(width); - test_utils::PseudorandomGenerator generator(min_value, max_value, - seed); + PseudorandomGenerator generator(min_value, max_value, seed); for (int i = 0; i < width; ++i) { result[i] = generator.get(); } @@ -482,8 +572,7 @@ std::unique_ptr> ClientLibraryTestBase::CreatePseudorandomR2( const int rows, const int cols, NativeT min_value, NativeT max_value, uint32 seed) { auto result = MakeUnique>(rows, cols); - test_utils::PseudorandomGenerator generator(min_value, max_value, - seed); + PseudorandomGenerator generator(min_value, max_value, seed); for (int y = 0; y < rows; ++y) { for (int x = 0; x < cols; ++x) { (*result)(y, x) = generator.get(); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 0853feeebd6f7a249cf767e1f8a63675d4bddd27..8853ed9e5780672d4006c326291767b8b5253f56 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -20,10 +20,12 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" @@ -42,26 +44,26 @@ TEST_F(ClientTest, ExecuteWithLayout) { for (const std::vector& transfer_layout : layouts) { b.Add(b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})); - auto computation = b.Build(); - ASSERT_TRUE(computation.ok()) << computation.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, execute_layout); - std::unique_ptr data = - client_->Execute(computation.ValueOrDie(), {}, &execution_options) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr data, + client_->Execute(computation, {}, &execution_options)); std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - transfer_layout); + Literal::CreateR2WithLayout( + {{11, 22}, {33, 44}}, LayoutUtil::MakeLayout(transfer_layout)); - auto computed = client_->Transfer(*data, &expected_literal->shape()); + TF_ASSERT_OK_AND_ASSIGN( + auto computed, client_->Transfer(*data, &expected_literal->shape())); - LiteralTestUtil::AssertEqualShapesAndLayouts( - expected_literal->shape(), computed.ValueOrDie()->shape()); - LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); + LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(), + computed->shape()); + LiteralTestUtil::ExpectEqual(*expected_literal, *computed); } } } @@ -72,8 +74,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), b.ConstantR2({{10, 20}, {30, 40}})}); - auto computation = b.Build(); - ASSERT_TRUE(computation.ok()) << computation.status(); + TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build()); ExecutionOptions execution_options = execution_options_; // Create a result shape with one element column major and the other row @@ -85,10 +86,9 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, /*minor_to_major=*/{1, 0})}); - auto result = - client_ - ->ExecuteAndTransfer(computation.ValueOrDie(), {}, &execution_options) - .ConsumeValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN( + auto result, + client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, result->tuple_literals(0)); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, @@ -107,5 +107,42 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } +TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { + Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; + Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr const_arg, + client_->TransferToServer(*Literal::CreateR2({{5, 6}, {7, 8}}))); + + ComputationBuilder b(client_, TestName() + ".add"); + b.Add(b.Parameter(0, shape, "param_0"), + b.ConstantR2({{1, 2}, {3, 4}})); + TF_ASSERT_OK_AND_ASSIGN(add_with_one_arg, b.Build()); + + // We can't really test parallel execution on CPU since all of the cores in a + // CPU are presented as a single device. So for now we test "parallel" + // execution on a single device. + std::vector computation_instances; + TF_ASSERT_OK_AND_ASSIGN(std::vector devices, + client_->GetDeviceHandles(1)); + ASSERT_EQ(devices.size(), 1); + + ExecutionOptions options = execution_options_; + *options.add_device_handles() = devices[0]; + computation_instances.push_back(Client::ComputationInstance( + add_with_one_arg, {const_arg.get()}, options, nullptr)); + + TF_ASSERT_OK_AND_ASSIGN(auto results, + client_->ExecuteParallel(computation_instances)); + auto expected_result = Literal::CreateR2({{6, 8}, {10, 12}}); + + TF_ASSERT_OK_AND_ASSIGN( + auto result_literal, + client_->Transfer(*results[0], &expected_result->shape())); + + LiteralTestUtil::ExpectEqual(*expected_result, *result_literal); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index 43ea7f6019415a171123ee0315533b8a3b1ff984..e472408dcf7ed5fec74e886fd0092ce47ee2e7eb 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -19,8 +19,11 @@ namespace xla { StatusOr> CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { - return backend().compiler()->Compile(std::move(hlo_module), - backend().default_stream_executor()); + TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses( + std::move(hlo_module), + backend().default_stream_executor())); + return backend().compiler()->RunBackend(std::move(hlo_module), + backend().default_stream_executor()); } StatusOr> diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 707e439245c29a1ddf80bfd9205aa14b0d4765f6..0f780fa87ef98fd5c48726ef83fa8efc1e90fbf7 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -138,13 +138,13 @@ XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) { // layouts. Use these arrays as parameters to a simple computation. If the // layout of the array changes then computation should be recompiled (cache // miss). - auto rowmaj_array = test_utils::CreateR2LiteralWithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0}); + auto rowmaj_array = Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({1, 0})); auto rowmaj_handle = client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie(); - auto colmaj_array = test_utils::CreateR2LiteralWithLayout( - {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1}); + auto colmaj_array = Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1})); auto colmaj_handle = client_->TransferToServer(*colmaj_array).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index b2e9743af79d0e4658451e7a9522c338036851ba..5226a78386824a94572d3e5cc3329677108a910a 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -71,24 +71,27 @@ class ComputeConstantTest : public ::testing::Test { StatusOr> ComputeConstantLiteral( Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, Layout* output_layout = nullptr) { - TF_ASSIGN_OR_RETURN(auto computed, - builder->ComputeConstant(operand, output_layout)); + ComputationBuilder* builder, Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}) { + TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant( + operand, output_layout, parameters)); return std::move(computed); } template - StatusOr ComputeConstantScalar(Client* client, - const ComputationDataHandle& operand, - ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(auto literal, - ComputeConstantLiteral(client, operand, builder)); + StatusOr ComputeConstantScalar( + Client* client, const ComputationDataHandle& operand, + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice parameters = {}) { + TF_ASSIGN_OR_RETURN( + auto literal, + ComputeConstantLiteral(client, operand, builder, nullptr, parameters)); return literal->Get({}); } bool IsConstant(const ComputationDataHandle& operand, - ComputationBuilder* builder) { - StatusOr result = builder->IsConstant(operand); + ComputationBuilder* builder, int64 num_parameters = 0) { + StatusOr result = builder->IsConstant(operand, num_parameters); EXPECT_TRUE(result.ok()) << result.status(); return result.ok() ? result.ValueOrDie() : false; } @@ -138,7 +141,25 @@ TEST_F(ComputeConstantTest, ScalarRng) { } } -TEST_F(ComputeConstantTest, DirectParam) { +TEST_F(ComputeConstantTest, Param) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"); + auto computation = b.Add(param, b.ConstantR0(1.5f)); + + std::vector arguments; + arguments.emplace_back(*Literal::CreateR0(42.5f)); + EXPECT_TRUE(IsConstant(computation, &b, arguments.size())); + + auto value = + ComputeConstantScalar(client, computation, &b, arguments); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 44.0f); + } +} + +TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); ComputationBuilder b(client, TestName()); @@ -152,7 +173,7 @@ TEST_F(ComputeConstantTest, DirectParam) { } } -TEST_F(ComputeConstantTest, IndirectParam) { +TEST_F(ComputeConstantTest, IndirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); ComputationBuilder b(client, TestName()); @@ -243,8 +264,8 @@ XLA_TEST_F(ComputeConstantTest, Layout) { ASSERT_TRUE(computed.ok()) << computed.status(); std::unique_ptr expected_literal = - test_utils::CreateR2LiteralWithLayout({{11, 22}, {33, 44}}, - layout); + Literal::CreateR2WithLayout({{11, 22}, {33, 44}}, + LayoutUtil::MakeLayout(layout)); LiteralTestUtil::AssertEqualShapesAndLayouts( expected_literal->shape(), computed.ValueOrDie()->shape()); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cbfacaea53952b02596eb3e84b13a5749335651d --- /dev/null +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -0,0 +1,238 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class ConditionalOpTest : public ClientLibraryTestBase { + protected: + Computation CreateR0F32ConstantComputation(float value) { + ComputationBuilder builder(client_, "Constant"); + builder.Parameter(0, empty_tuple_, "tuple"); + builder.ConstantR0(value); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32IdentityComputation() { + ComputationBuilder builder(client_, "Identity"); + builder.Parameter(0, r0f32_, "x"); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32CeilComputation() { + ComputationBuilder builder(client_, "Ceil"); + auto param = builder.Parameter(0, r0f32_, "param"); + builder.Ceil(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0F32FloorComputation() { + ComputationBuilder builder(client_, "Ceil"); + auto param = builder.Parameter(0, r0f32_, "param"); + builder.Floor(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateAddTupleComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Add(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateAddR0Computation() { + return CreateAddTupleComputation("AddR0", tuple_2_r0f32_); + } + + Computation CreateAddR1Computation() { + return CreateAddTupleComputation("AddR1", tuple_2_r1s2f32_); + } + + Computation CreateSubTupleComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Sub(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateSubR0Computation() { + return CreateSubTupleComputation("SubR0", tuple_2_r0f32_); + } + + Computation CreateSubR1Computation() { + return CreateSubTupleComputation("SubR1", tuple_2_r1s2f32_); + } + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); + Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})}); + Shape empty_tuple_ = ShapeUtil::MakeTupleShape({}); + ErrorSpec error_spec_{0.001}; +}; + +// Test true and false computations that do not take any parameters. +XLA_TEST_F(ConditionalOpTest, Parameters0) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({}); + auto true_computation = CreateR0F32ConstantComputation(56.0f); + auto false_computation = CreateR0F32ConstantComputation(12.0f); + auto result = builder.Conditional(pred, operands, true_computation, operands, + false_computation); + + ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); +} + +// Test true and false computations that take in 1 parameter. +XLA_TEST_F(ConditionalOpTest, Parameters1) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto identity = CreateR0F32IdentityComputation(); + auto result = + builder.Conditional(pred, operand1, identity, operand2, identity); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// true. +XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR0Computation(), + operands, CreateSubR0Computation()); + + ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// false. +XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR0Computation(), + operands, CreateSubR0Computation()); + + ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is true. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR1Computation(), + operands, CreateSubR1Computation()); + + ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is false. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = builder.Conditional(pred, operands, CreateAddR1Computation(), + operands, CreateSubR1Computation()); + + ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); +} + +// Test the case where one conditional is nested within another. +XLA_TEST_F(ConditionalOpTest, NestedConditionals) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); + auto pred_cond = inner_builder.GetTupleElement(param0, 0); + auto true_operand = inner_builder.GetTupleElement(param0, 1); + auto false_operand = inner_builder.GetTupleElement(param0, 2); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0F32CeilComputation(), false_operand, + CreateR0F32FloorComputation()); + auto inner_builder_result = inner_builder.Build(); + + ComputationBuilder builder(client_, TestName()); + auto pred1 = builder.ConstantR0(true); + auto pred2 = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(1.1f); + auto operand2 = builder.ConstantR0(12.2f); + auto operand3 = builder.ConstantR0(43.3f); + auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); + builder.Conditional(pred1, tuple_operand, + inner_builder_result.ConsumeValueOrDie(), operand3, + CreateR0F32IdentityComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test a mismatch in the shape of the true operand and true computation. +XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + builder.Conditional(pred, operands, CreateAddR1Computation(), operands, + CreateSubR0Computation()); + + auto result = builder.Build(); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("true_operand must match the shape of the " + "only parameter of true_computation")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index b0a63bccbb93f226175beff2e30e2a243fdca1d3..896b34fb6e2762c14bd9ec2bf1ba13c548d4cf60 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -39,8 +39,8 @@ class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 0, 2, 2, 3, 0, 1, 2, - 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, + 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -49,13 +49,23 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 0, 1, 2, 3, 2, 3, 2, - 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, + 2, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); } +// Tests the convolution operation with invalid output dimension numbers. +TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { + auto dimension_numbers_status = + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, + 1, 2, 3); + ASSERT_FALSE(dimension_numbers_status.ok()); + ASSERT_THAT(dimension_numbers_status.status().error_message(), + ::testing::HasSubstr("output are not unique")); +} + XLA_TEST_F(ConvolutionDimensionNumbersTest, TwoConvsWithDifferentDimensionNumbers) { auto input_array = MakeUnique>(2, 3, 5, 5); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 0cc2e5fb7e655884f3334426a684dd3ce00d4052..2924c08615fa706bb19addf04bf58e1d5dd5a659 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -82,177 +82,127 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) { ComputationBuilder builder(client_, TestName()); auto lhs = builder.ConstantR4FromArray4D(*alhs); auto rhs = builder.ConstantR4FromArray4D(*arhs); - builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid); - - ComputeAndCompareR4(&builder, *aexpected, {}, error_spec_); + ComputeAndCompare(&builder, conv, {}, error_spec_); } TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); - } - - Array4D input(1, 1, 1, 2); - input.FillWithYX(Array2D({ + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillWithYX(Array2D({ {1, 2}, })); - Array4D filter(1, 1, 1, 2); - filter.FillWithYX(Array2D({ + Array4D filter_data(1, 1, 1, 2); + filter_data.FillWithYX(Array2D({ {5, 6}, })); - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests valid padding for 2D convolution in raster space. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kValid); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 2, 2); + Array4D filter_data(1, 1, 2, 2); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ {5, 6}, {7, 8}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests same padding for 2D convolution in raster space. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 2, 2); + Array4D filter_data(1, 1, 2, 2); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ {5, 6}, {7, 8}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } // Tests same padding for 2D convolution in raster space with an odd sized // kernel. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - builder.Conv(input, filter, {1, 1}, Padding::kSame); - } + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - Array4D input(1, 1, 4, 4); + Array4D input_data(1, 1, 4, 4); // clang-format off - input.FillWithYX(Array2D({ + input_data.FillWithYX(Array2D({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D filter(1, 1, 3, 3); + Array4D filter_data(1, 1, 3, 3); // clang-format off - filter.FillWithYX(Array2D({ + filter_data.FillWithYX(Array2D({ { 5, 6, 7}, { 8, 9, 10}, {11, 12, 13}, })); // clang-format on - - std::unique_ptr> aexpected = - ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR4(&builder, *aexpected, - {input_literal.get(), filter_literal.get()}, - error_spec_); + ComputeAndCompare(&builder, conv, + {*Literal::CreateFromArray(input_data), + *Literal::CreateFromArray(filter_data)}, + error_spec_); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { @@ -420,9 +370,12 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); - dnums.add_spatial_dimensions(3); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(3); + dnums.add_output_spatial_dimensions(3); dnums.set_input_feature_dimension(4); dnums.set_output_feature_dimension(4); dnums.add_kernel_spatial_dimensions(0); @@ -473,8 +426,10 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); @@ -508,6 +463,54 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { error_spec_); } +// Test fixture to run convolution tests with and without convolution +// canonicalization enabled. +class ConvolveWithAndWithoutCanonicalization + : public ConvolutionTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, + DISABLED_ON_GPU(Convolve2D_NoSpatialDims)) { + if (GetParam()) { + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "convolution-canonicalization"); + } + ComputationBuilder builder(client_, TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 29}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {4, 10}); + + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_feature_dimension(0); + dnums.set_input_batch_dimension(1); + dnums.set_kernel_input_feature_dimension(0); + dnums.set_kernel_output_feature_dimension(1); + dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(1); + auto conv = builder.ConvWithGeneralDimensions(input, filter, {}, + Padding::kValid, dnums); + + Array2D param0(4, 29); + param0.FillUnique(); + + Array2D param1(4, 10); + param1.FillUnique(); + + Array2D expected_result(29, 10); + expected_result.Fill(0); + + ComputeAndCompare( + &builder, conv, + {*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)}, + error_spec_); +} + +INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation, + ConvolveWithAndWithoutCanonicalization, + ::testing::Values(true, false)); + struct Convolve1DTestParam { int64 input_feature; int64 output_feature; @@ -540,7 +543,8 @@ XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) { ConvolutionDimensionNumbers dnums; dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); dnums.set_input_feature_dimension(2); dnums.set_output_feature_dimension(2); dnums.add_kernel_spatial_dimensions(0); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 9b36e3722b8f8a5d01c426425fdfb0c4b9ae3a16..9c1145def8c11f1222c63adf006102887d49f00d 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -320,9 +320,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { Array4D input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D(input_array); - const Array4D filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0 - 0, 100, 0, // row 1 - 10, 0, 1}); // row 2 + const Array4D filter_array(1, 1, 3, 3, + {10000, 0, 1000, // row 0 + 0, 100, 0, // row 1 + 10, 0, 1}); // row 2 auto filter = builder.ConstantR4FromArray4D(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kSame); @@ -472,7 +473,9 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { builder.Conv(input, filter, {1, 1}, Padding::kValid); std::vector expected_data = { - 23, 33, 43, + 23, + 33, + 43, }; Array4D expected(bs, 1, 1, 1, expected_data); ComputeAndCompareR4(&builder, expected, {}, error_spec_); @@ -669,10 +672,11 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { std::iota(input_data.begin(), input_data.end(), 1.0); Array4D input_array(1, 1, 3, 4, input_data); - Array4D filter_array(1, 1, 4, 3, {100, 10, 1, // - 200, 20, 2, // - 300, 30, 3, // - 400, 40, 4}); + Array4D filter_array(1, 1, 4, 3, + {100, 10, 1, // + 200, 20, 2, // + 300, 30, 3, // + 400, 40, 4}); auto input = builder.ConstantR4FromArray4D(input_array); auto filter = builder.ConstantR4FromArray4D(filter_array); builder.ConvGeneralDilated( @@ -681,9 +685,10 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { /*rhs_dilation=*/{}, ComputationBuilder::CreateDefaultConvDimensionNumbers()); - Array4D expected(1, 1, 3, 5, {204, 40, 406, 60, 608, // - 1518, 180, 1821, 210, 2124, // - 4146, 460, 4651, 510, 5156}); + Array4D expected(1, 1, 3, 5, + {204, 40, 406, 60, 608, // + 1518, 180, 1821, 210, 2124, // + 4146, 460, 4651, 510, 5156}); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } @@ -926,7 +931,8 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { ComputeAndCompareR4(&builder, *expected, {}, error_spec_); } -XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) { +XLA_TEST_F(ConvolutionVariantsTest, + RandomData_Input16x16x16x16_Filter16x16x16x16) { constexpr int bs = 16; constexpr int iz = 16; constexpr int oz = 16; @@ -976,8 +982,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1018,8 +1026,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1060,8 +1070,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1099,8 +1111,10 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // NHWC input format. dnums.set_input_batch_dimension(0); dnums.set_output_batch_dimension(0); - dnums.add_spatial_dimensions(1); - dnums.add_spatial_dimensions(2); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); dnums.set_input_feature_dimension(3); dnums.set_output_feature_dimension(3); @@ -1131,7 +1145,8 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // Conv([1,2,3], Reverse([5,6]), padding_low=1) // into // BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1) -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardInputLowPaddingLessThanHighPadding) { ComputationBuilder builder(client_, TestName()); auto gradients = builder.ConstantR4FromArray4D( @@ -1149,7 +1164,8 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) // Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3) // into // BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1)) -XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardInputLowPaddingGreaterThanHighPadding) { ComputationBuilder builder(client_, TestName()); auto gradients = builder.ConstantR4FromArray4D( @@ -1206,7 +1222,8 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { ComputeAndCompareR4(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_); } -XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) { +XLA_TEST_F(ConvolutionVariantsTest, + BackwardFilterLowPaddingLessThanHighPadding) { ComputationBuilder builder(client_, TestName()); // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0 @@ -1230,7 +1247,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) } XLA_TEST_F(ConvolutionVariantsTest, - BackwardFilterLowPaddingGreaterThanHighPadding) { + BackwardFilterLowPaddingGreaterThanHighPadding) { ComputationBuilder builder(client_, TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4 diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index cf089d748dcd4f5db637ff9087c5fbc504c82572..2058cd04a5e765e22be1733c835f07e237afbfbd 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -177,15 +177,15 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 2.0}, {3.0, -4.0}}, - MinorToMajorForIsRowMajor(lhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 6.0}, {7.0, -4.0}}, - MinorToMajorForIsRowMajor(rhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -277,10 +277,64 @@ XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) { TestMatrixDot(260, 3, 520, false, false); } +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) { + TestMatrixDot(1, 8, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) { + TestMatrixDot(1, 130, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) { + TestMatrixDot(1, 8, 130, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) { + TestMatrixDot(1, 290, 130, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) { + TestMatrixDot(2, 1, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) { + TestMatrixDot(8, 8, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) { + TestMatrixDot(16, 1, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) { + TestMatrixDot(16, 3, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) { + TestMatrixDot(3, 3, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) { + TestMatrixDot(29, 29, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) { + TestMatrixDot(1, 8, 2, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) { + TestMatrixDot(1, 2, 8, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) { + TestMatrixDot(259, 258, 1, true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) { + TestMatrixDot(259, 258, 1, false, true); +} + XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestSquareMatrixDot(false, false); } XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { @@ -291,10 +345,24 @@ XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { TestSquareMatrixDot(true, false); } -TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { + TestSquareMatrixDot(true, true); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { + TestSquareMatrixDot(false, false); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { + TestSquareMatrixDot(false, true); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { + TestSquareMatrixDot(true, false); +} + +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { + TestSquareMatrixDot(true, true); } XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { @@ -306,15 +374,15 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, bool rhs_row_major) { auto lhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}, - MinorToMajorForIsRowMajor(lhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ - ->TransferToServer(*test_utils::CreateR2LiteralWithLayout( + ->TransferToServer(*Literal::CreateR2WithLayout( {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}, - MinorToMajorForIsRowMajor(rhs_row_major))) + LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major)))) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -330,35 +398,64 @@ void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major, } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(false, false); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = true; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(false, true); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = false; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(true, false); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); + TestNonsquareMatrixDot(true, true); } XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { TestNonsquareMatrixDot(); } -XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) { - TestNonsquareMatrixDot(); +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFF) { + TestNonsquareMatrixDot(false, false); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorFT) { + TestNonsquareMatrixDot(false, true); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTF) { + TestNonsquareMatrixDot(true, false); +} + +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64MajorToMinorTT) { + TestNonsquareMatrixDot(true, true); +} + +XLA_TEST_F(DotOperationTest, MatrixVectorC64) { + auto lhs_handle = + client_ + ->TransferToServer(*Literal::CreateR2WithLayout( + {{1.0, 2.0, 3.0, -4.0}}, LayoutUtil::MakeLayout({1, 0}))) + .ConsumeValueOrDie(); + auto rhs_handle = + client_ + ->TransferToServer(*Literal::CreateR2WithLayout( + {{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}, {-4.0, 4.0}}, + LayoutUtil::MakeLayout({1, 0}))) + .ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto prim_type = primitive_util::NativeToPrimitiveType(); + auto result = builder.Dot( + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {1, 4}), "lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {4, 2}), "rhs")); + + Array2D expected({{30.0, -2.0}}); + + ComputeAndCompareR2( + &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 19252f50f25eee42e4e492b7f0e2ec3960c62126..8baaf39e3cf8fa7f6fa4a0224c1297f82e0d92aa 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -250,9 +250,6 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { // Slice at dimension boundaries. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {5}, {0, 1, 2, 3, 4, 8, 9, 10}); - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, - {0, 1, 2, 3, 4, 5, 8, 9}); // Zero-sized update. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {}, {2}, {0, 1, 2, 3, 4, 5, 6, 7}); @@ -269,9 +266,6 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { // Slice at dimension boundaries. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 1}, {{1, 2, 3}, {4, 5, 6}, {7, 10, 11}}); - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, - {{1, 2, 3}, {4, 5, 6}, {7, 8, 10}}); // Zero-sized update. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{}}, {2, 1}, {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}); @@ -289,10 +283,20 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, {1, 1, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 13}, {11, 15}}}); + } + + template + void TestWrap() { // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {8, 9, 10}, {6}, + {10, 1, 2, 3, 4, 5, 8, 9}); + // R2 Shape: [3, 3] + RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {{10, 11}}, {2, 2}, + {{1, 2, 3}, {4, 5, 6}, {11, 8, 10}}); + // R3 Shape: [2, 3, 2] RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {{{13}, {15}}}, - {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 13}}}); + {1, 2, 1}, {{{1, 2}, {3, 4}, {5, 6}}, {{7, 15}, {9, 10}, {11, 13}}}); } template @@ -425,6 +429,12 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } + +XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } + +XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } + XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { // Slice at dimension start. RunR1({false, false, true, true, false, true, true, false}, @@ -497,19 +507,13 @@ XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } -// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle -// wrapping as expected. -XLA_TEST_F(DynamicUpdateSliceTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousMultipleWrapping))) { +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { // Multiple element, wrapping. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } -// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle -// wrapping as expected. -XLA_TEST_F(DynamicUpdateSliceTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousTooLarge))) { +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) { // Multiple element, update size larger than operand. std::vector operand_shape({4, 5, 2}); RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); @@ -555,7 +559,11 @@ void BM_DynamicSlice(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. - auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0) + auto shape_size_fn = [client](const Shape& shape) { + return client->backend().transfer_manager()->GetByteSizeRequirement(shape); + }; + auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0, + shape_size_fn) .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index a8f6488996087b57e3121ce2c7de918070950c72..2686afccc216095345dbb7b43e916fbbe7c8ea39 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -770,8 +770,6 @@ void BM_ParallelFusion(int num_iters) { auto client = ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); - auto* transfer_manager = - TransferManager::GetForPlatform(platform).ValueOrDie(); int device_ordinal = client->default_device_ordinal(); // Computation shape parameters. @@ -796,29 +794,23 @@ void BM_ParallelFusion(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Transfer literals to device. - auto buffer0 = - ScopedShapedBuffer::Allocate(shape0, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); auto param0_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param0_literal, buffer0->mutable_buffer({}))); - - auto buffer1 = - ScopedShapedBuffer::Allocate(shape1, &allocator, /*device_ordinal=*/0) + std::unique_ptr buffer0 = + client->LiteralToShapedBuffer(*param0_literal, device_ordinal) .ConsumeValueOrDie(); + auto param1_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param1_literal, buffer1->mutable_buffer({}))); - - auto buffer2 = - ScopedShapedBuffer::Allocate(shape2, &allocator, /*device_ordinal=*/0) + std::unique_ptr buffer1 = + client->LiteralToShapedBuffer(*param1_literal, device_ordinal) .ConsumeValueOrDie(); + auto param2_literal = Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); - ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *param2_literal, buffer2->mutable_buffer({}))); + std::unique_ptr buffer2 = + client->LiteralToShapedBuffer(*param2_literal, device_ordinal) + .ConsumeValueOrDie(); // Build executable. std::unique_ptr executable = @@ -828,7 +820,7 @@ void BM_ParallelFusion(int num_iters) { ExecutableBuildOptions()) .ConsumeValueOrDie(); - se::Stream stream(executors[client->default_device_ordinal()]); + se::Stream stream(executors[device_ordinal]); stream.Init(); // Initialize thread pool. diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d73c05ff92578209143e0679558848160cae99bd..e7a18828db064f82cad2a15f797b557d2be1f88a 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -15,13 +15,22 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include #include #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -30,18 +39,72 @@ namespace se = ::perftools::gputools; namespace xla { +namespace { + +using tensorflow::StringPiece; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::optional; + +constexpr char kInterpreter[] = "interpreter"; + +// Helper functions to get test and reference platforms. +se::Platform* GetReferencePlatform() { + auto result = PlatformUtil::GetPlatform(kInterpreter); + TF_CHECK_OK(result.status()) << "could not get interpreter platform"; + return result.ValueOrDie(); +} + +se::Platform* GetTestPlatform() { + auto result = PlatformUtil::GetDefaultPlatform(); + TF_CHECK_OK(result.status()) << "could not get test platform"; + return result.ValueOrDie(); +} + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { + if (lhs.parameters_size() != rhs.parameters_size()) { + return false; + } + for (int i = 0; i < lhs.parameters_size(); i++) { + if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) { + return false; + } + } + return ShapeUtil::Equal(lhs.result(), rhs.result()); +} + +ProgramShape GetProgramShapeWithLayout(const HloModule& module) { + ProgramShape program_shape; + const auto* entry = module.entry_computation(); + for (const auto* param : entry->parameter_instructions()) { + *program_shape.add_parameters() = param->shape(); + *program_shape.add_parameter_names() = param->name(); + } + *program_shape.mutable_result() = entry->root_instruction()->shape(); + return program_shape; +} + +} // namespace + +HloTestBase::HloTestBase() + : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} + +HloTestBase::HloTestBase(se::Platform* test_platform, + se::Platform* reference_platform) + : test_runner_(test_platform), reference_runner_(reference_platform) {} + /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} +/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. debug_options.add_xla_disable_hlo_passes("constant_folding"); - - config.set_debug_options(debug_options); - - return MakeUnique(TestName(), VersionedComputationHandle(), - config); + return debug_options; } StatusOr HloTestBase::Execute( @@ -49,25 +112,168 @@ StatusOr HloTestBase::Execute( tensorflow::gtl::ArraySlice arguments, Shape* result_shape) { - return runner_.Execute(std::move(module), arguments, result_shape); + return test_runner_.Execute(std::move(module), arguments, result_shape); } se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - return runner_.TransferToDevice(literal).ValueOrDie(); + return test_runner_.TransferToDevice(literal).ValueOrDie(); } std::unique_ptr HloTestBase::TransferFromDevice( const Shape& shape, se::DeviceMemoryBase device_base) { - return runner_.TransferFromDevice(shape, device_base).ValueOrDie(); + return test_runner_.TransferFromDevice(shape, device_base).ValueOrDie(); } std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { - return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie(); + return test_runner_.ExecuteAndTransfer(std::move(module), arguments) + .ValueOrDie(); +} + +StatusOr> HloTestBase::MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor) { + std::unique_ptr reference_module = test_module.Clone(); + const auto& program_shape = GetProgramShapeWithLayout(test_module); + + if (reference_preprocessor != nullptr) { + reference_preprocessor(reference_module.get()); + if (!ProgramShapesEqual(program_shape, + GetProgramShapeWithLayout(*reference_module))) { + return InvalidArgument( + "reference preprocessor must not modify the program shape"); + } + } + TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(), + reference_module.get())); + return std::move(reference_module); +} + +template +StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor) { + static_assert( + std::is_same::value || + std::is_same, LiteralPtr>::value, + "The LiteralPtr type only accepts Literal* or std::unique_ptr."); + TF_RETURN_IF_ERROR( + VerifyHloModule(*test_runner_.backend().platform(), module.get())); + TF_ASSIGN_OR_RETURN(auto reference_module, + MakeReferenceModule(*module, reference_preprocessor)); + + // Execute on two backends. + TF_ASSIGN_OR_RETURN( + auto test, + test_runner_.Execute(std::move(module), arguments, run_hlo_passes)); + TF_ASSIGN_OR_RETURN(auto reference, + reference_runner_.Execute(std::move(reference_module), + arguments, run_hlo_passes)); + return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + error); +} + +template +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/true, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +template +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/false, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompare>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompareNoHloPasses>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() << "failed parsing hlo textual IR"; + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() << "failed parsing hlo textual IR"; + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModule(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); } -Backend& HloTestBase::backend() { return runner_.backend(); } +Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 7f068dce36be3546298de2f06bf6d33446d07ca2..3cbbb7aa247dda3e5b6589a2a6aa74cf074babe7 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -24,31 +24,74 @@ limitations under the License. #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" namespace xla { -// A base class for tests which build and run HLO code. This is a lower level of -// abstraction than using the client interface and enables, for one, explicitly -// building a graph of HLO instructions to run. +// A base class for tests which build and/or run HLO code. The class includes +// support for running an HLO module on two platforms and compare the results. +// This is a lower level of abstraction than using the client interface and +// enables, for one, explicitly building a graph of HLO instructions to run. +// +// This can also be used to write text/file-based test cases. Note that the test +// target is responsible for linking the needed backends. A covenient way to do +// this is to make it an xla_test: it will generate test targets linking with +// the respective backends, which will be used as the test backend; the +// interpreter backend is already linked with hlo_test_base so it will be the +// default reference backend. For example, if you want to compare both cpu vs. +// interpreter, and gpu vs. interpreter, you can: +// +// xla_test ( +// name = "sample_text_test", +// srcs = ["sample_text_test.cc"], +// backends = [ +// "cpu", +// "gpu", +// ], +// deps = [ +// "//third_party/tensorflow/compiler/xla/tests:hlo_test_base", +// ... +// ], +// ) +// +// For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { protected: - HloTestBase() {} + // This uses the interpreter backend as the reference backend and + // automatically finds another supported backend as the test backend. If the + // interpreter is the only supported backend, it will be both the test backend + // and the reference backend. + HloTestBase(); + + // If your test doesn't use interpreter as the reference backend, you can use + // this constructor. Note that your test target is responsible for linking in + // both needed backends. + HloTestBase(::perftools::gputools::Platform* test_platform, + ::perftools::gputools::Platform* reference_platform); ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. It's recommended to use this method to - // create all HloModules for tests. + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. static std::unique_ptr CreateNewModule(); + // Populates debug options from command-line flags and adjusts the options for + // testing. It is recommended to use this when you need to pass in + // DebugOptions, e.g. when creating a module from a string or a file. + static DebugOptions GetDebugOptionsForTest(); + // Executes the given module and returns a global data handle. StatusOr Execute( std::unique_ptr module, @@ -71,6 +114,73 @@ class HloTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); + // Executes the given hlo module on two backends and compares results. + // + // 'arguments': the input of the hlo module. The LiteralPtr type accepts + // Literal* or std::unique_ptr. + // + // 'error': if has value, expects the results to be near (within the error + // bound). Otherwise, expects the results to be equal. + // + // 'reference_preprocessor': the module should be ready to run on the test + // backend, but it might need to be tailored so that it is able to run on the + // reference backend. Note that the program shape of the module must not be + // modified. + template + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Same as above, except that the module will be executed without Hlo + // optimization. + template + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Executes an hlo module with fake inputs and compares the results. + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Same as above, except that the module will be executed without Hlo + // optimization. + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Convenient wrappers for executing and comparing an hlo module with fake + // input. Module can be passed in directly, or parsed from an hlo_string, + // or loaded from a file. + ::testing::AssertionResult RunAndCompare( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPasses( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + // Convenience method to force the layout of a given parameter in a module. // The layout of parameter number 'param_no' in the 'module' is set to // 'layout'. @@ -101,12 +211,31 @@ class HloTestBase : public ::testing::Test { static string TestName(); - // Returns the backend owned by the HloRunner. + // Returns the backend owned by the test runner. Backend& backend(); - HloRunner runner_; + HloRunner test_runner_; + HloRunner reference_runner_; ErrorSpec error_spec_{0.0001}; + + private: + // Given the test module, makes a reference module that is ready to run on the + // reference platform. This assumes that the given module is ready to run on + // the test platform. + StatusOr> MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor); + + // Runs the module on two platforms with or without running hlo passes and + // compares the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. + template + StatusOr<::testing::AssertionResult> RunAndCompareInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/isolated_convolution.hlo b/tensorflow/compiler/xla/tests/isolated_convolution.hlo new file mode 100644 index 0000000000000000000000000000000000000000..9452780930efbb1ecc13b35cd4ab53678d36c37f --- /dev/null +++ b/tensorflow/compiler/xla/tests/isolated_convolution.hlo @@ -0,0 +1,8 @@ +HloModule convolution.167: + +ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] { + %parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0) + %parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1) + ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f +} + diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 95a52ecd2f5cfc97ec1ccba7d1b7ca6257a8267e..bf6631a4310d3504e4dfa8c46bf66125a94b9315 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -100,6 +100,58 @@ namespace xla { ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); } +/* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector> converted_elements; + for (const auto& element : literal.tuple_literals()) { + converted_elements.push_back(ConvertBF16ToF32(element)); + } + return Literal::MakeTupleOwned(std::move(converted_elements)); + } + + if (literal.shape().element_type() != BF16) { + return MakeUnique(literal); + } + Shape converted_shape = literal.shape(); + converted_shape.set_element_type(F32); + auto converted = Literal::CreateFromShape(converted_shape); + if (!ShapeUtil::HasZeroElements(converted_shape)) { + std::vector index(converted_shape.dimensions_size(), 0); + do { + converted->Set(index, + static_cast(literal.Get(index))); + } while (IndexUtil::BumpIndices(converted_shape, &index)); + } + return converted; +} + +/* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( + const Literal& literal) { + if (ShapeUtil::IsTuple(literal.shape())) { + std::vector> converted_elements; + for (const auto& element : literal.tuple_literals()) { + converted_elements.push_back(ConvertF32ToBF16(element)); + } + return Literal::MakeTupleOwned(std::move(converted_elements)); + } + + if (literal.shape().element_type() != F32) { + return MakeUnique(literal); + } + Shape converted_shape = literal.shape(); + converted_shape.set_element_type(BF16); + auto converted = Literal::CreateFromShape(converted_shape); + if (!ShapeUtil::HasZeroElements(converted_shape)) { + std::vector index(converted_shape.dimensions_size(), 0); + do { + converted->Set( + index, static_cast(literal.Get(index))); + } while (IndexUtil::BumpIndices(converted_shape, &index)); + } + return converted; +} + namespace { string Hostname() { @@ -116,16 +168,18 @@ template ::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) { auto ulhs = tensorflow::bit_cast(lhs); auto urhs = tensorflow::bit_cast(rhs); + auto lhs_double = static_cast(lhs); + auto rhs_double = static_cast(rhs); if (ulhs != urhs) { return ::testing::AssertionFailure() << tensorflow::strings::Printf( "floating values are not bitwise-equal; and equality testing " "was requested: %s=%g=%a vs %s=%g=%a", tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs)) .c_str(), - lhs, lhs, + lhs_double, lhs_double, tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs)) .c_str(), - rhs, rhs); + rhs_double, rhs_double); } return ::testing::AssertionSuccess(); } @@ -149,6 +203,10 @@ template // Specializations for floating types that do bitwise comparisons when equality // comparison is requested. template <> +::testing::AssertionResult CompareEqual(bfloat16 lhs, bfloat16 rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> ::testing::AssertionResult CompareEqual(float lhs, float rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } @@ -238,6 +296,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case U64: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case BF16: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case F32: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; @@ -272,23 +333,37 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, return result; } -/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, - const Literal& actual) { +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualTuple( + const Literal& expected, const Literal& actual) { VLOG(1) << "expected: " << expected.ToString(); VLOG(1) << "actual: " << actual.ToString(); - ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); - ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); + if (!ShapeUtil::IsTuple(expected.shape()) || + !ShapeUtil::IsTuple(actual.shape())) { + return ::testing::AssertionFailure() + << "tuples expected shape = " << expected.shape().ShortDebugString() + << " actual shape = " << actual.shape().ShortDebugString(); + } AssertEqualShapes(expected.shape(), actual.shape()); for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { const auto& expected_element = expected.tuple_literals(i); const auto& actual_element = actual.tuple_literals(i); if (ShapeUtil::IsTuple(expected_element.shape())) { - ExpectEqualTuple(expected_element, actual_element); + auto ret = EqualTuple(expected_element, actual_element); + if (!ret) { + return ret; + } } else { - ExpectEqual(expected_element, actual_element); + return Equal(expected_element, actual_element); } } + + return ::testing::AssertionSuccess(); +} + +/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, + const Literal& actual) { + EXPECT_TRUE(EqualTuple(expected, actual)); } namespace { @@ -331,6 +406,9 @@ class NearComparator { multi_index_.resize(expected.shape().dimensions_size(), 0); switch (expected.shape().element_type()) { + case BF16: + ExpectLiteralsNear(expected, actual, 0); + break; case F32: ExpectLiteralsNear(expected, actual, 0); break; @@ -516,6 +594,13 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, << message; } +template <> +bool NearComparator::ExpectValuesNear(bfloat16 expected, + bfloat16 actual) { + return ExpectValuesNear(static_cast(expected), + static_cast(actual)); +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( @@ -544,8 +629,7 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, if (!ShapeUtil::IsTuple(expected.shape()) || !ShapeUtil::IsTuple(actual.shape())) { return ::testing::AssertionFailure() - << "tuples expected expected shape = " - << expected.shape().ShortDebugString() + << "tuples expected shape = " << expected.shape().ShortDebugString() << " actual shape = " << actual.shape().ShortDebugString(); } AssertEqualShapes(expected.shape(), actual.shape()); @@ -579,6 +663,32 @@ void NearComparator::ExpectNear(complex64 expected, complex64 actual, EXPECT_TRUE(NearTuple(expected, actual, error)); } +/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + bool is_tuple = ShapeUtil::IsTuple(expected.shape()); + if (error.has_value()) { + if (is_tuple) { + VLOG(1) << "Expects near tuple"; + return NearTuple(expected, actual, *error); + } + VLOG(1) << "Expects near"; + return Near(expected, actual, *error); + } + if (is_tuple) { + VLOG(1) << "Expects equal tuple"; + return EqualTuple(expected, actual); + } + VLOG(1) << "Expects equal"; + return Equal(expected, actual); +} + +/*static*/ void LiteralTestUtil::ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + EXPECT_TRUE(NearOrEqual(expected, actual, error)); +} + /* static */ string LiteralTestUtil::MultiIndexAsString( tensorflow::gtl::ArraySlice multi_index) { return tensorflow::strings::StrCat( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 467d44b857b74d2a38e9b3f8a32a9b1d39a4a10d..f53553c70170bdcda717e72ffd791016effd0774 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -59,6 +60,16 @@ class LiteralTestUtil { static void AssertEqualShapesAndLayouts(const Shape& expected, const Shape& actual); + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); + + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. + static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); + // Asserts that the expected and actual literals are (bitwise) equal for all // elements in the literal. Also, asserts that the rank, dimensions sizes, and // primitive type are equal. @@ -100,6 +111,10 @@ class LiteralTestUtil { static void ExpectR4EqualArray4D(const Array4D& expected, const Literal& actual); + // Returns whether the two tuples are equal. + static ::testing::AssertionResult EqualTuple( + const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT; + // Expects that the values of the elements in the expected and actual tuples // are equal. Tuples are matched recursively. static void ExpectEqualTuple(const Literal& expected, const Literal& actual); @@ -167,6 +182,19 @@ class LiteralTestUtil { static void ExpectNearTuple(const Literal& expected, const Literal& actual, const ErrorSpec& error); + // If the error spec is given, returns whether the expected and the actual are + // within the error bound; otherwise, returns whether they are equal. Tuples + // will be compared recursively. + static ::testing::AssertionResult NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; + + // If the error spec is given, expects the expected and the actual to be near; + // otherwise, expects them to be equal. Tuples will be compared recursively. + static void ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error); + // Returns a multi-dimensional index as a string. For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, // dimension 1 equal to 8. diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index 458258e7ee1fee6964275c51ef38de5ff2ccd7b1..b5b95967ff9162301a092f3a57996e0f3f78658f 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -14,50 +14,147 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/stream_executor.h" namespace xla { namespace { -class LLVMCompilerTest : public HloTestBase {}; - -XLA_TEST_F(LLVMCompilerTest, CompilerHooks) { - int pre_opt_hook_call_count = 0; - int post_opt_hook_call_count = 0; - - auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) { - ++pre_opt_hook_call_count; - return Status::OK(); - }; - auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) { - ++post_opt_hook_call_count; - return Status::OK(); - }; - - // Create HLO module, and run the compiler. - auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(42.0))); - - auto hlo_module = CreateNewModule(); - hlo_module->AddEntryComputation(builder.Build()); - - auto compiler = static_cast(backend().compiler()); - compiler->SetPreOptimizationHook(pre_opt_hook); - compiler->SetPostOptimizationHook(post_opt_hook); - - ASSERT_TRUE( - compiler - ->Compile(std::move(hlo_module), backend().default_stream_executor()) - .ok()); - - // Test that hooks were called. - EXPECT_EQ(1, pre_opt_hook_call_count); - EXPECT_EQ(1, post_opt_hook_call_count); +class LLVMCompilerTest : public ::testing::Test { + public: + void SetUp() override { + Platform *platform = FindPlatform(); + ASSERT_NE(platform, nullptr); + + BackendOptions backend_options; + backend_options.set_platform(platform); + StatusOr> backend_or_status = + Backend::CreateBackend(backend_options); + ASSERT_IS_OK(backend_or_status.status()); + backend_ = backend_or_status.ConsumeValueOrDie(); + } + + ~LLVMCompilerTest() override {} + + protected: + using Platform = ::perftools::gputools::Platform; + + explicit LLVMCompilerTest(string platform_name) + : platform_name_(std::move(platform_name)) {} + + void TestCompilerHooks(LLVMCompiler *compiler) { + int pre_opt_hook_call_count = 0; + int post_opt_hook_call_count = 0; + + auto pre_opt_hook = [&pre_opt_hook_call_count](const llvm::Module &) { + ++pre_opt_hook_call_count; + return Status::OK(); + }; + auto post_opt_hook = [&post_opt_hook_call_count](const llvm::Module &) { + ++post_opt_hook_call_count; + return Status::OK(); + }; + + // Create HLO module, and run the compiler. + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + + auto hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + compiler->SetPreOptimizationHook(pre_opt_hook); + compiler->SetPostOptimizationHook(post_opt_hook); + + ASSERT_TRUE(compiler + ->RunBackend(std::move(hlo_module), + backend_->default_stream_executor()) + .ok()); + + // Test that hooks were called. + EXPECT_EQ(1, pre_opt_hook_call_count); + EXPECT_EQ(1, post_opt_hook_call_count); + } + + void TestMultiModuleCompilation(LLVMCompiler *compiler) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + + std::unique_ptr hlo_module = CreateNewModule(); + hlo_module->AddEntryComputation(builder.Build()); + + std::vector> modules; + modules.push_back(hlo_module->Clone()); + modules.push_back(std::move(hlo_module)); + + std::vector> executors; + executors.push_back({backend_->default_stream_executor()}); + executors.push_back({backend_->default_stream_executor()}); + + EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors))); + } + + private: + Platform *FindPlatform() { + for (Platform *platform : + PlatformUtil::GetSupportedPlatforms().ConsumeValueOrDie()) { + if (platform->Name() == platform_name_) { + return platform; + } + } + return nullptr; + } + + string platform_name_; + std::unique_ptr backend_; + + static string TestName() { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + static std::unique_ptr CreateNewModule() { + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); + } +}; + +class CpuCompilerTest : public LLVMCompilerTest { + public: + CpuCompilerTest() : LLVMCompilerTest("Host") {} +}; + +class GpuCompilerTest : public LLVMCompilerTest { + public: + GpuCompilerTest() : LLVMCompilerTest("CUDA") {} +}; + +TEST_F(CpuCompilerTest, HooksTest) { + cpu::CpuCompiler compiler; + TestCompilerHooks(&compiler); +} + +TEST_F(GpuCompilerTest, HooksTest) { + gpu::GpuCompiler compiler; + TestCompilerHooks(&compiler); } +TEST_F(CpuCompilerTest, MultiModuleCompilation) { + cpu::CpuCompiler compiler; + TestMultiModuleCompilation(&compiler); +} + +TEST_F(GpuCompilerTest, MultModuleCompilation) { + gpu::GpuCompiler compiler; + TestMultiModuleCompilation(&compiler); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index 329b53012f58c8d084cc05f9a567a8aa432c4a3a..ad71d40197fe48b4343ee5f5f7f71b282a05cbf5 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -136,16 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToShapedBuffer( - *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}}, - /*minor_to_major=*/{0, 1})); + auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. - auto y_array = LiteralToShapedBuffer( - *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}}, - /*minor_to_major=*/{1, 0})); + auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( + {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), LayoutUtil::MakeLayout({1, 0}))); @@ -876,11 +874,13 @@ XLA_TEST_F(LocalClientExecuteTest, tensorflow::ThreadOptions(), "execute_thread", [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); - ASSERT_IS_OK(local_client_->TransferToInfeed( - *Literal::CreateR1({-5.0, 123.0, 42.0}))); + ASSERT_IS_OK(local_client_->TransferToInfeedLocal( + *Literal::CreateR1({-5.0, 123.0, 42.0}), + local_client_->default_device_ordinal())); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, - local_client_->TransferFromOutfeed(&shape)); + local_client_->TransferFromOutfeedLocal( + shape, local_client_->default_device_ordinal())); LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); } @@ -906,9 +906,12 @@ void BM_LocalClientOverhead(int num_iters) { builder.Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); - auto buffer = - ScopedShapedBuffer::Allocate(shape, &allocator, /*device_ordinal=*/0) - .ConsumeValueOrDie(); + auto shape_size_fn = [client](const Shape& shape) { + return client->backend().transfer_manager()->GetByteSizeRequirement(shape); + }; + auto buffer = ScopedShapedBuffer::Allocate( + shape, &allocator, /*device_ordinal=*/0, shape_size_fn) + .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( executors[device_ordinal], *literal, buffer->mutable_buffer({}))); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index c11e1df0a7890a6c3aada5ff47494b42fdaf3b9d..062a9246e49598d5d03dce8c1f437138923449bf 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -12,13 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS #include "tensorflow/compiler/xla/tests/local_client_test_base.h" #include -#define EIGEN_USE_THREADS - #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/map_util.h" @@ -136,29 +135,10 @@ std::unique_ptr LocalClientTestBase::LiteralToShapedBuffer( .ConsumeValueOrDie(); } -void LocalClientTestBase::CopyShapedBufferToLiteral( - const ShapedBuffer& shaped_buffer, ShapeIndex* index, Literal* literal) { - const Shape& shape = ShapeUtil::GetSubshape(shaped_buffer.shape(), *index); - if (ShapeUtil::IsTuple(shape)) { - *literal->mutable_shape() = shape; - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - Literal* element_literal = literal->add_tuple_literals(); - index->push_back(i); - CopyShapedBufferToLiteral(shaped_buffer, index, element_literal); - index->pop_back(); - } - } else { - ASSERT_IS_OK(transfer_manager_->TransferLiteralFromDevice( - stream_executor_, shaped_buffer.buffer(*index), shape, shape, literal)); - } -} - std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer) { - auto literal = MakeUnique(); - ShapeIndex index; - CopyShapedBufferToLiteral(shaped_buffer, &index, literal.get()); - return literal; + return local_client_->ShapedBufferToLiteral(shaped_buffer) + .ConsumeValueOrDie(); } ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions() diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index 3edfcb656ed8278d403103f0cfd820a10892476a..f0c73f04f6eb67b2e9cb5e111eccdc3818059b2b 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -93,10 +93,6 @@ class LocalClientTestBase : public ::testing::Test { std::unique_ptr ShapedBufferToLiteral( const ShapedBuffer& shaped_buffer); - // Helper for converting a ShapedBuffer into a literal. - void CopyShapedBufferToLiteral(const ShapedBuffer& shaped_buffer, - ShapeIndex* index, Literal* literal); - // Execute the given computation on the local client. With and without // options. StatusOr> ExecuteLocally( diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 2ef392508d14cf6dc14b2c979f07a79bc60d7426..2b0f7e6e80c48435ca55432a2afa3b6d69162625 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -405,13 +405,13 @@ TEST_F(MapTest, MapBinaryAdder) { // for Map that used to fail in shape inference (b/28989438). XLA_TEST_F(MapTest, AddWithMixedLayouts) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = - test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0}); + std::unique_ptr param0_literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({1, 0})); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr param1_literal = - test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1}); + std::unique_ptr param1_literal = Literal::CreateR2WithLayout( + {{10, 20}, {30, 40}}, LayoutUtil::MakeLayout({0, 1})); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 22d2b917a1d55f4f453e21c2d8fea38e32ff796b..89fa6ed9f7fe590f3ac872cce48a329b2894048a 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -76,8 +76,11 @@ class MultiOutputFusionTest : public HloTestBase { elem_shape2, HloOpcode::kAdd, broadcast, param1)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( elem_shape2, HloOpcode::kSubtract, param1, broadcast)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2)); + HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -133,8 +136,11 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {size, 1}), add)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1}), HloOpcode::kDot, sub, reshape)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index fda4389f479cdc7a659e4d7c8a2facba55e17e83..24c5daed3d09dc447ef92a4bc7e0d7185ec903ed 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -252,8 +252,8 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { } // Only run the 3,000-parameter tests in opt mode to avoid test timeouts. -// Timeout last observed on 2017-09-12. -#ifndef NDEBUG +// Timeout last observed on 2017-11-20. +#ifdef NDEBUG // TODO(b/65525254) Fails on GPU on 2017-09-10 because we try to reserve too // much space in parameter memory for the kernel. @@ -334,6 +334,106 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); } +// Test large number of parameters flowing into a while-loop. +// Construct conceptually the following HLO graph: +// +// p0 = parameter(0) +// p1 = parameter(1) +// ... +// pN = parameter(N) +// result = while (false) { +// p0 += (1, 1); +// p1 += (1, 1); +// ... +// pN += (1, 1) +// } +// result = {p0, p1, ..., pN} +// +// TODO(b/70173746): Times out during compilation on GPU and CPU-parallel +// backend as of 2017-12-03. +XLA_TEST_F(ParamsTest, DISABLED_ON_CPU_PARALLEL( + DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) { + ComputationBuilder builder(client_, TestName()); + + std::vector> param_data_owner; + constexpr int kParamCount = 1900; + std::vector params; + std::vector parameter_shapes; + for (int i = 0; i < kParamCount; ++i) { + std::unique_ptr literal = Literal::CreateR1({i, i}); + param_data_owner.push_back( + std::move(client_->TransferToServer(*literal)).ValueOrDie()); + ComputationDataHandle param = + builder.Parameter(i, literal->shape(), "param"); + params.push_back(param); + parameter_shapes.push_back(literal->shape()); + } + + // Add bool parameter for the loop condition. Use a parameter HLO instead of a + // constant because DCE may eliminate the while-body otherwise. + std::unique_ptr bool_literal = Literal::CreateR0(false); + param_data_owner.push_back( + std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + ComputationDataHandle bool_param = + builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); + params.push_back(bool_param); + parameter_shapes.push_back(bool_literal->shape()); + + auto init = builder.Tuple(params); + + // Create a computation for the condition: while(bool_param). + Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto condition_parameter = + builder.Parameter(0, while_shape, "condition_parameter"); + builder.GetTupleElement(condition_parameter, kParamCount); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add {1, 1} to the each tuple element. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); + std::vector updates; + for (int i = 0; i < kParamCount; ++i) { + auto add = builder.Add(builder.GetTupleElement(body_parameter, i), + builder.ConstantR1({1, 1})); + updates.push_back(add); + } + // Add bool parameter. + updates.push_back(builder.GetTupleElement(body_parameter, kParamCount)); + + builder.Tuple(updates); + body = builder.Build().ConsumeValueOrDie(); + } + + auto loop = builder.While(condition, body, init); + + std::vector outputs; + for (int i = 0; i < kParamCount; ++i) { + outputs.push_back(builder.GetTupleElement(loop, i)); + } + builder.Tuple(outputs); + + std::vector param_data; + param_data.reserve(param_data_owner.size()); + for (const std::unique_ptr& data : param_data_owner) { + param_data.push_back(data.get()); + } + + std::vector> elements; + std::vector ptrs; + for (int i = 0; i < kParamCount; ++i) { + elements.push_back(Literal::CreateR1({i, i})); + ptrs.push_back(elements.back().get()); + } + ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); +} + #endif XLA_TEST_F(ParamsTest, diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 3500e8dc28570fe216f53b746c3757e080aa689f..10e44b274a8a9f3ac28dc40d7b1938d24a9ee40c 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -90,7 +90,7 @@ TEST_F(PredTest, ConstantR2Pred) { builder.ConstantR2({{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { { 011 }, - { 100 }, + { 100 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -119,7 +119,9 @@ TEST_F(PredTest, AnyR1VacuouslyFalse) { TEST_F(PredTest, AnyR2True) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2({ - {false, false, false}, {false, false, false}, {false, false, true}, + {false, false, false}, + {false, false, false}, + {false, false, true}, }); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, true, {}); @@ -128,7 +130,9 @@ TEST_F(PredTest, AnyR2True) { TEST_F(PredTest, AnyR2False) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2({ - {false, false, false}, {false, false, false}, {false, false, false}, + {false, false, false}, + {false, false, false}, + {false, false, false}, }); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, false, {}); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 7bc3185c367f076c9a7d211c9799557e1a91d92f..b09ccdd679b6c8f628e40f78f58dbd1734926af6 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -352,15 +352,13 @@ XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) { XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); } XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { @@ -369,15 +367,13 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/false, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 6c9b62b48d8bb2ad93b2ce98839e5e52d8eaa8cc..b32df74312ed1b513bcdd161c1516c5a5a2f0faf 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -41,16 +41,40 @@ limitations under the License. namespace xla { namespace { -class ReduceWindowTest : public ClientLibraryTestBase { +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. +static std::array use_bfloat16_params{false}; +#endif + +class ReduceWindowTestBase : public ClientLibraryTestBase { public: - ReduceWindowTest() : builder_(client_, TestName()) {} + ErrorSpec DefaultErrorSpec() const { + if (use_bfloat16()) { + return ErrorSpec(1e-1, 5e-2); + } else { + return ErrorSpec(1e-3, 1e-3); + } + } +}; + +class ReduceWindowTest : public ::testing::WithParamInterface, + public ReduceWindowTestBase { + public: + ReduceWindowTest() : builder_(client_, TestName()) { + set_use_bfloat16(GetParam()); + } void ReduceWindowAdd(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, builder_.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder_), + auto init = + CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); } @@ -58,30 +82,32 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow( - input, builder_.ConstantLiteral(Literal::MinValue(F32)), - CreateScalarMax(), window_dimensions, window_strides, padding); + auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); + builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions, + window_strides, padding); } void ReduceWindowMin(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, - builder_.ConstantLiteral(Literal::MaxValue(F32)), - CreateScalarMinComputation(F32, &builder_), + auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarMinComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); } ComputationBuilder builder_; }; -TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { - const auto input = builder_.ConstantR1({1, 1, 1, 1}); - const auto init_value = builder_.ConstantR0(0); +TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({1, 1, 1, 1}), &builder_); + const auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); builder_.ReduceWindow(input, init_value, - CreateScalarAddComputation(F32, &builder_), + CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{1, 2}, /*window_strides=*/{1}, Padding::kValid); ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) @@ -90,79 +116,97 @@ TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { ::testing::HasSubstr("Want input dimensions size")); } -TEST_F(ReduceWindowTest, Min3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); +// Regression test for b/68964348. +TEST_P(ReduceWindowTest, R0ReduceWindow) { + const auto input = + CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); + const auto init = + CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), + /*window_dimensions=*/{}, + /*window_strides=*/{}, Padding::kSame); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, + ErrorSpec(0.00001)); +} + +TEST_P(ReduceWindowTest, Min3In5Stride2) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({100, 1}), {}, + ErrorSpec(0.00001)); } -XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) { +XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { Array4D input_array(1, 0, 2, 1); - - const auto input = builder_.ConstantR4FromArray4D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, NonSquareSmall) { +TEST_P(ReduceWindowTest, NonSquareSmall) { Array4D input_array(1, 2, 2, 1); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, MiddleDimsSmall) { +TEST_P(ReduceWindowTest, MiddleDimsSmall) { Array4D input_array(1, 3, 3, 1); - input_array.FillRandom(2.f); - - const auto input = builder_.ConstantR4FromArray4D(input_array); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, Along2ndMinorDim) { +TEST_P(ReduceWindowTest, Along2ndMinorDim) { Array4D input_array(3, 6, 7, 32); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); // The parameters of this reduction mimic feature norm (e.g. LRN). int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2Dims) { +TEST_P(ReduceWindowTest, AmongMajor2Dims) { Array4D input_array(4, 4, 6, 8); input_array.FillWithMinorDimNum(); + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); int win_len = 3; int win_stride = 1; Padding padding = Padding::kSame; - const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); // Reduce only along the x and y dimensions, according to the win_len. ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -170,18 +214,20 @@ TEST_F(ReduceWindowTest, AmongMajor2Dims) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { Array4D input_array(9, 12, 4, 89); - input_array.FillRandom(2.0f); + input_array.FillRandom(2.f, 2.f); int win_len = 3; int win_stride = 2; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; // Reduce only along the x and y dimensions, according to the win_len. @@ -192,20 +238,21 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // TODO(b/32173947): Test support for arbitrary-sized padding. -TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { +TEST_P(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { Array4D input_array(9, 12, 4, 89); // simulate Dim0IsMinor layout - input_array.FillRandom(2.0f); + input_array.FillRandom(2.f, 2.f); int64 rank = 4; int win_len = 3; int win_stride = 2; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; // Reduce only along the x and y dimensions, according to the win_len. @@ -222,26 +269,28 @@ TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x2) { Array3D input_array(2, 1, 2); input_array(0, 0, 0) = 1000; input_array(0, 0, 1) = 100; input_array(1, 0, 0) = 10; input_array(1, 0, 1) = 1; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kValid); Array3D expected(2, 1, 1); expected(0, 0, 0) = 1100; expected(1, 0, 0) = 11; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { Array3D input_array(2, 1, 3); input_array(0, 0, 0) = 100; input_array(0, 0, 1) = 10; @@ -249,17 +298,18 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { input_array(1, 0, 0) = 500; input_array(1, 0, 1) = 50; input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 2}, Padding::kValid); Array3D expected(2, 1, 1); expected(0, 0, 0) = 110; expected(1, 0, 0) = 550; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { +XLA_TEST_P(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { Array3D input_array(2, 1, 3); input_array(0, 0, 0) = 100; input_array(0, 0, 1) = 10; @@ -267,7 +317,7 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { input_array(1, 0, 0) = 500; input_array(1, 0, 1) = 50; input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); + const auto input = CreateConstantFromArray(input_array, &builder_); ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kSame); @@ -278,30 +328,34 @@ XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { expected(1, 0, 0) = 550; expected(1, 0, 1) = 55; expected(1, 0, 2) = 5; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(expected), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. -XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { +XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D input_array(1, 2, 2, 1); input_array(0, 0, 0, 0) = 1; input_array(0, 0, 1, 0) = 2; input_array(0, 1, 0, 0) = 3; input_array(0, 1, 1, 0) = 4; + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kValid; - - const Shape scalar = ShapeUtil::MakeShape(F32, {}); + const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); auto b = builder_.CreateSubBuilder("unusual"); auto lhs = b->Parameter(0, scalar, "lhs"); auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(b->Add(lhs, rhs), b->ConstantR0(8.0f)); + b->Min(b->Add(lhs, rhs), + CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); Computation reduce_fn = b->BuildAndNoteError(); - builder_.ReduceWindow(input, builder_.ConstantR0(3.0f), reduce_fn, - /*window_dimensions=*/{1, 1, 2, 1}, - /*window_strides=*/{1, 1, 1, 1}, padding); + builder_.ReduceWindow( + input, + CreateConstantFromLiteral(*Literal::CreateR0(3.0f), &builder_), + reduce_fn, + /*window_dimensions=*/{1, 1, 2, 1}, + /*window_strides=*/{1, 1, 1, 1}, padding); const auto reduce_func = [](float arg1, float arg2) { return std::min(arg1 + arg2, 8.0f); @@ -312,17 +366,19 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R4UnitWindow) { +TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); - input_array.Fill(1.0f); + input_array.FillRandom(2.f, 2.f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -330,15 +386,11 @@ TEST_F(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); @@ -348,56 +400,15 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(arg_literal->Populate(generator)); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - // Non-monotonic output layout with minor dims trivial. + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); + std::vector output_layout = {1, 5, 3, 2, 0, 4}; std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - b.AddInstruction(HloInstruction::CreateReduceWindow( - result_shape, input, init_value, window, add_func)); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); auto out_generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { @@ -405,82 +416,37 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(expected->Populate(out_generator)); - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); - - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6Add) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); + auto shape = ShapeUtil::MakeShape(F32, input_dims); + std::unique_ptr arg_literal = Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - Shape shape = ShapeUtil::MakeShape(F32, {8, 8, 6, 6, 8, 8}); - b.AddInstruction(HloInstruction::CreateReduceWindow(shape, input, init_value, - window, add_func)); + + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 9.0f); - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); - - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -490,20 +456,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -513,20 +478,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -536,13 +500,11 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Array4D input_array(6, 4, 10, 130); input_array.FillRandom(2.0f); @@ -551,7 +513,7 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Padding padding = Padding::kSame; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); // Reduce only along the x and y dimensions, according to the win_len. ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -559,36 +521,42 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add24In1152_NoOverlap) { +XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); - const auto input = builder_.ConstantR1(input_vector); + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {32, 32, 32, 32, 32, 32, 32, 32, 32}, - {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral( + &builder_, + *Literal::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add128In128Stride128) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); +XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { + std::vector input_vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {1088}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + DefaultErrorSpec()); } // Regression test for a bug that appeared in Inception (b/34784899). -TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { Array2D input_array(14, 14, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {14, 14}); + const auto input = CreateConstantFromArray(input_array, &builder_); int win_len = 3; int stride = 1; @@ -598,13 +566,14 @@ TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {6, 4}); + ComputationDataHandle input = builder_.Broadcast( + CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -612,9 +581,13 @@ TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } +INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, + ::testing::ValuesIn(use_bfloat16_params)); + enum Reducer { kAdd, kMax }; struct R4ReduceWindowTestData { @@ -628,30 +601,36 @@ struct R4ReduceWindowTestData { }; string R4ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(data.param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(data.param.pad_high, "x"), // - (data.param.reducer == kAdd) ? "add" : "max"); - CHECK(data.param.reducer == kAdd || data.param.reducer == kMax); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // + (param.reducer == kAdd) ? "add" : "max"); + CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. std::replace(str.begin(), str.end(), '-', 'n'); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R4ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface { +class R4ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: + R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + void DoIt() { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); const float kInitValue = 0.0f; @@ -660,23 +639,24 @@ class R4ReduceWindowTest input.FillIota(1); std::unique_ptr input_literal = Literal::CreateR4FromArray4D(input); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); std::vector> padding(4); for (int i = 0; i < 4; ++i) { padding[i] = {param.pad_low[i], param.pad_high[i]}; } - auto parameter = b.Parameter(0, input_literal->shape(), "p0"); - auto pad_value = b.ConstantR0(kInitValue); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd - ? CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); b.ReduceWindowWithGeneralPadding( /*operand=*/parameter, - /*init_value=*/pad_value, + /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, @@ -694,8 +674,8 @@ class R4ReduceWindowTest /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareR4(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); } }; @@ -824,9 +804,11 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowTestInstantiation, R4ReduceWindowTest, - ::testing::ValuesIn(kR4ReduceWindowTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowTestInstantiation, R4ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; @@ -849,10 +831,11 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowLargeTestInstantiation, - R4ReduceWindowLargeTest, - ::testing::ValuesIn(kR4ReduceWindowLargeTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); struct R2ReduceWindowTestData { int64 base_bounds[2]; @@ -900,26 +883,33 @@ struct R2ReduceWindowTestData { }; string R2ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", data.param.layout[0], "_", data.param.layout[1], // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__padding_", param.padding == Padding::kSame ? "same" : "valid", // + "__layout_", param.layout[0], "_", param.layout[1], // + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R2ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R2ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R2ReduceWindowTest, Add) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd); const float kInitValue = 0.0f; @@ -927,12 +917,15 @@ TEST_P(R2ReduceWindowTest, Add) { std::unique_ptr input_literal = Literal::CreateR2FromArray2DWithLayout( input, LayoutUtil::MakeLayout(param.layout)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/CreateScalarAddComputation(F32, &b), + + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, /*padding=*/param.padding); @@ -940,90 +933,145 @@ TEST_P(R2ReduceWindowTest, Add) { /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/param.padding); - ComputeAndCompareR2(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); } -INSTANTIATE_TEST_CASE_P(R2ReduceWindowTestInstantiation, R2ReduceWindowTest, - ::testing::ValuesIn(kR2TestCases), - R2ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R2ReduceWindowTestInstantiation, R2ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR2TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R2ReduceWindowTestDataToString); struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; int64 strides[1]; - Padding padding; + int64 pad_low[1]; + int64 pad_high[1]; Reducer reducer; } kR1TestCases[] = { {/*base_bounds=*/{1}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{3}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{4}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30}, + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{30}, /*strides=*/{27}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 17}, + /*window_bounds=*/{7}, /*strides=*/{64}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32}, + /*pad_low=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{32}, /*strides=*/{56}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{3}, /*strides=*/{2}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{0}, + /*pad_high=*/{5}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{5}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kAdd}, }; string R1ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), + "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), + "__strides_", tensorflow::str_util::Join(param.strides, "x"), + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R1ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R1ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R1ReduceWindowTest, DoIt) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd || param.reducer == kMax); const float kInitValue = 0.0f; @@ -1031,18 +1079,24 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + + std::vector> padding(1); + padding[0] = {param.pad_low[0], param.pad_high[0]}; auto computation = param.reducer == kAdd - ? CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/computation, - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1052,14 +1106,17 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + /*stride=*/param.strides, + /*padding=*/padding); - ComputeAndCompareR1(&b, tensorflow::gtl::ArraySlice(*expected), - {input_arg.get()}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateR1(*expected), + {input_arg.get()}, DefaultErrorSpec()); } -INSTANTIATE_TEST_CASE_P(R1ReduceWindowTestInstantiation, R1ReduceWindowTest, - ::testing::ValuesIn(kR1TestCases), - R1ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R1ReduceWindowTestInstantiation, R1ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR1TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R1ReduceWindowTestDataToString); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index 72c68f24a0a954deb0564e9a0e924edfaf5b5484..ddd50d7a5864d73de7916ce736bb7cd40c1c4dc9 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -41,326 +41,467 @@ limitations under the License. namespace xla { namespace { -class ReshapeTest : public ClientLibraryTestBase { +// Use a bool parameter to indicate whether to use bfloat16. +class ReshapeTest : public ::testing::WithParamInterface, + public ClientLibraryTestBase { public: + ReshapeTest() { set_use_bfloat16(GetParam()); } + ErrorSpec zero_error_spec_{0.0}; }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { +XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. -XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { +XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - auto reshape = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); - ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); + auto expected_literal = Literal::CreateR0(1.0f); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { +XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR0(1.0f); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - a = builder.Neg(a); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + auto a = builder.Neg(parameter); auto reshape = builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); - ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, - zero_error_spec_); + auto expected_literal = Literal::CreateR1({-1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial0x3) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(0, 3); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // TODO(b/29185393): Make this work with the GPU backend. The GPU backend // does not handle zero-sized shapes correctly. Failed last on 2017-05-15 // with an incorrect result rank. -XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(0, 3)); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {param0_data.get()}, - zero_error_spec_); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial3x0) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(3, 0); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional row vector to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial1x3) { +XLA_TEST_P(ReshapeTest, Trivial1x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional column vector to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial3x1) { +XLA_TEST_P(ReshapeTest, Trivial3x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f}, {2.0f}, {3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Splits an empty vector into an empty matrix. -XLA_TEST_F(ReshapeTest, R1ToR2_0_To_2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateR1({}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Splits a vector into a matrix. -XLA_TEST_F(ReshapeTest, R1ToR2_6_To_2x3) { +XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); - Array2D expected_2x3({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected_2x3, {}, zero_error_spec_); + auto input_literal = + Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); + auto expected_literal = + Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 2x0 array to a 0x2 array. -XLA_TEST_F(ReshapeTest, Reshape0x2To2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 0}); - - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional row vector to a column vector. -XLA_TEST_F(ReshapeTest, ReshapeRowToCol) { +XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { ComputationBuilder builder(client_, TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); - auto a = builder.ConstantR2FromArray2D(*simple); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{3, 1}); + auto input_literal = Literal::CreateFromArray(*simple); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); - ComputeAndCompareR2(&builder, *expected, {}, zero_error_spec_); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array. -XLA_TEST_F(ReshapeTest, TransposeAsReshape) { +XLA_TEST_P(ReshapeTest, TransposeAsReshape) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 4}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 0x4 array with ComputationBuilder::Trans. -XLA_TEST_F(ReshapeTest, Transpose0x4) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 4)); - auto result = builder.Transpose(a, {1, 0}); - - ComputeAndCompareR2(&builder, Array2D(4, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array with ComputationBuilder::Trans. -XLA_TEST_F(ReshapeTest, Transpose4x3) { +XLA_TEST_P(ReshapeTest, Transpose4x3) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Transpose(a, {1, 0}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(6, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 3, 0, 0}); - - ComputeAndCompareR4(&builder, Array4D(2, 3, 0, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); + auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ReshapeR4ToR2ZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR4FromArray4D(Array4D(2, 3, 4, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{24, 0}); - - ComputeAndCompareR2(&builder, Array2D(24, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 6}); - - auto expected2x6 = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); - ComputeAndCompareR2(&builder, *expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); + + auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -// Reshapes a 2-dimensional array with dimensions that are not just a -// rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 6)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 0}); - - ComputeAndCompareR2(&builder, Array2D(3, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{2, 6}); - - Array2D expected2x6({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, - {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); - ComputeAndCompareR2(&builder, expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); + Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, + {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); + auto expected_literal = Literal::CreateFromArray(expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // The following tests use the same input 3D array; they test the examples we // show for the Reshape operation in the operation_semantics document. // TODO(b/34503277): find a way to show this code in the documentation without // duplication on the TF documentation server. -Array3D v_array_for_doc_R3_tests({{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}); - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_012) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_120) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, - 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 20, 30}, - {40, 11, 21}, - {31, 41, 12}, - {22, 32, 42}, - {15, 25, 35}, - {45, 16, 26}, - {36, 46, 17}, - {27, 37, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{2, 6, 2}); - Array3D expected( +static Array3D ArrayForDocR3Tests() { + return Array3D({{{10, 11, 12}, {15, 16, 17}}, + {{20, 21, 22}, {25, 26, 27}}, + {{30, 31, 32}, {35, 36, 37}}, + {{40, 41, 42}, {45, 46, 47}}}); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, + 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, + 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); + auto expected_literal = Literal::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareR3(&builder, expected, {}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses the low dimensions of a 4D tensor to get a 2D matrix, without @@ -378,23 +519,26 @@ XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { // Then we collapse Z be collapsed so we just end up with planes: // // 1 2 3 4 5 6 1 2 3 4 5 6 -XLA_TEST_F(ReshapeTest, FullyConnectedCollapse) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { ComputationBuilder builder(client_, TestName()); Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); - auto a = builder.ConstantR4FromArray4D(t2x2x2x3); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{1, 2, 3}); - - Array2D expected2x12( + auto input_literal = Literal::CreateFromArray(t2x2x2x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); + auto expected_literal = Literal::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected2x12, {}, zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // As above, but uses reshape directly. -XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { ComputationBuilder builder(client_, TestName()); Array4D t(2, 1, 2, 2); t(0, 0, 0, 0) = 0; @@ -405,51 +549,67 @@ XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 0, 1) = 5; t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; - auto a = builder.ConstantR4FromArray4D(t); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{2, 4}); - - Array2D expected({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareR2(&builder, expected, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(t); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); + + auto expected_literal = + Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshape various ranks to a scalar. -XLA_TEST_F(ReshapeTest, ToScalar) { +XLA_TEST_P(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { ComputationBuilder b(client_, TestName()); - auto input = Literal::CreateR1({83.0f}); + auto input_literal = Literal::CreateR1({83.0f}); std::vector ones(rank, 1); // this is {1, ..., 1}. std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); - *input->mutable_shape() = ShapeUtil::MakeShape(F32, ones); - b.Reshape(b.ConstantLiteral(*input), dimensions, {}); + *input_literal->mutable_shape() = ShapeUtil::MakeShape(F32, ones); + + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &b, ¶meter); + b.Reshape(parameter, dimensions, {}); - ComputeAndCompareR0(&b, 83.0f, {}, zero_error_spec_); + auto expected_literal = Literal::CreateR0(83.0f); + ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + zero_error_spec_); } } -XLA_TEST_F(ReshapeTest, BadDimensions) { +XLA_TEST_P(ReshapeTest, BadDimensions) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1}), {}, {}); - EXPECT_THAT(ExecuteToString(&b, {}), - ::testing::HasSubstr("dimensions not a permutation")); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {}, {}); + EXPECT_THAT( + ExecuteToString(&b, {}), + ::testing::HasSubstr("not a permutation of the operand dimensions")); } -XLA_TEST_F(ReshapeTest, BadNewSizes) { +XLA_TEST_P(ReshapeTest, BadNewSizes) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1, 2}), {1}, {}); + auto input_literal = Literal::CreateR1({1.0f, 2.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } -XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { - const Shape parameter_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); +XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, parameter_shape, "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); - // clang-format off - auto literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ + auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { { {0, 1}, @@ -473,8 +633,12 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }, LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on - std::unique_ptr input = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + Array2D expected_array({ {0, 1, 2, 3, 100, 101, 102, 103}, {222, 333, 444, 555, 666, 777, 888, 999}, @@ -483,72 +647,75 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, + {1, 0}); std::unique_ptr actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); + if (use_bfloat16()) { + expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + } LiteralTestUtil::ExpectEqual(*expected, *actual); } -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}, {{{100, 101, 102, 103}}, {{104, 105, 106, 107}}}, {{{200, 201, 202, 203}}, {{204, 205, 206, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 100, 200, 1}}, {{101, 201, 2, 102}}}, {{{202, 3, 103, 203}}, {{4, 104, 204, 5}}}, {{{105, 205, 6, 106}}, {{206, 7, 107, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); @@ -558,12 +725,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); @@ -571,7 +736,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); @@ -581,12 +747,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); @@ -595,7 +759,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { } // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. -XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { +XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); @@ -605,12 +770,11 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, + /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { @@ -618,10 +782,12 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = Literal::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, NoopReshape) { +XLA_TEST_P(ReshapeTest, NoopReshape) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); @@ -631,18 +797,17 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Reshape(input, /*dimensions=*/{3, 0, 1, 2}, + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, + {2, 3, 0, 1}); std::unique_ptr output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, @@ -651,35 +816,45 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. - EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), - tensorflow::gtl::ArraySlice(output_literal->f32s())); + if (use_bfloat16()) { + auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + EXPECT_EQ(tensorflow::gtl::ArraySlice(expected->bf16s()), + tensorflow::gtl::ArraySlice(output_literal->bf16s())); + } else { + EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), + tensorflow::gtl::ArraySlice(output_literal->f32s())); + } } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { + ComputationBuilder builder(client_, TestName()); + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{0, 1, 2, 3}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {}); + ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{1, 3, 2, 0}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = Literal::CreateR4( + auto expected_2x4x3x1 = Literal::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -690,10 +865,10 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {}); + ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {2, 2, 2, 2}; @@ -705,12 +880,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -722,7 +897,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {1, 1, 250, 300}; @@ -734,12 +909,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -751,7 +926,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {5, 5, 1, 10}; @@ -763,12 +938,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -780,7 +955,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::mt19937 rng; std::uniform_real_distribution distribution; // This happens in NN-Builder MNIST. @@ -793,12 +968,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -810,7 +985,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {3, 3, 1, 3}; @@ -822,12 +997,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) @@ -839,5 +1014,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { zero_error_spec_, &expected->shape()); } +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::Bool()); +#else +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, + ::testing::ValuesIn(std::vector{false})); +#endif + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/sample_file_test.cc b/tensorflow/compiler/xla/tests/sample_file_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..31b104f4e37f77d47f56ff8183ee1de1cc22e44d --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_file_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create a file based testcase +// and compare results on gpu and cpu. + +#include +#include + +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SampleFileTest : public HloTestBase { + protected: + SampleFileTest() + : HloTestBase( + /*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(), + /*reference_platform=*/PlatformUtil::GetPlatform("cpu") + .ValueOrDie()) {} +}; + +TEST_F(SampleFileTest, Convolution) { + const string& filename = "compiler/xla/tests/isolated_convolution.hlo"; + string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + EXPECT_TRUE(RunAndCompareFromFile( + tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4f2b74e3dc9e80f50454b28eb6f2502cef3e681 --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_text_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create textual IR based +// testcases. + +#include +#include + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class SampleTextTest : public HloTestBase {}; + +TEST_F(SampleTextTest, Axpy) { + const string& hlo_string = R"( +HloModule axpy_module: +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001})); +} + +TEST_F(SampleTextTest, Tuple) { + const string& hlo_string = R"( +HloModule TupleCreate_module: +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, nullopt)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c21124750ad512cad69b1483e708613ee2857ac0..4db566f7841829359ea06fe25408048418c547ad 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -211,6 +212,13 @@ class SliceR1Test : public ClientLibraryTestBase, } }; +string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { + const R1Spec& spec = data.param; + return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, + spec.slice_start, spec.slice_limit, + spec.slice_stride); +} + XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_F64) { Run(GetParam()); } @@ -223,30 +231,66 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } -INSTANTIATE_TEST_CASE_P( // - SliceR1TestInstantiation, // - SliceR1Test, // - ::testing::Values( // - R1Spec{10, 0, 0, 1}, // - R1Spec{10, 7, 7, 1}, // - R1Spec{10, 2, 4, 1}, // - R1Spec{10, 2, 4, 2}, // - R1Spec{10, 0, 10, 1}, // - R1Spec{1024, 1024 - 4, 1024, 1}, // - R1Spec{4096, 7, 7 + 1024, 1}, // - R1Spec{10, 0, 10, 2}, // - R1Spec{10, 0, 10, 3}, // - R1Spec{10, 0, 10, 4}, // - R1Spec{10, 0, 10, 5}, // - R1Spec{10, 0, 10, 10}, // - R1Spec{500, 200, 400, 7}, // - R1Spec{4096, 1, 4095, 3}, // - R1Spec{2047, 1024 - 24, 1024 + 160, 31}, // - R1Spec{2047, 1, 2046, 3 * 128}, // - R1Spec{4096, 1024 + 3, 4095, 500}, // - R1Spec{8192, 0, 8192, 1024 * 3 + 400} // - ) // +// Tests for R1 slice ops. +// The format for each testcase is {input size, start, limit, stride}. +// clang-format off +INSTANTIATE_TEST_CASE_P( + SliceR1TestInstantiation, + SliceR1Test, + ::testing::Values( + R1Spec{10, 0, 0, 1}, + R1Spec{10, 7, 7, 1}, + R1Spec{10, 0, 5, 1}, + R1Spec{10, 3, 5, 1}, + R1Spec{10, 0, 10, 1}, + R1Spec{1024, 0, 5, 1}, + R1Spec{1024, 3, 5, 1}, + R1Spec{1024 + 17, 0, 5, 1}, + R1Spec{1024 + 17, 3, 5, 1}, + R1Spec{1024 + 17, 1024, 1024 + 6, 1}, + R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1}, + R1Spec{1024, 1024 - 4, 1024, 1}, + R1Spec{4 * 1024, 7, 7 + 1024, 1}, + R1Spec{4 * 1024, 0, 4 * 1024, 1}, + R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1}, + R1Spec{4 * 1024, 1024, 3 * 1024, 1}, + R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1}, + R1Spec{16 * 1024, 0, 5, 1}, + R1Spec{16 * 1024, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 0, 5, 1}, + R1Spec{16 * 1024 + 17, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1}, + R1Spec{64 * 1024, 0, 64 * 1024, 1}, + R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1}, + R1Spec{64 * 1024, 1024, 63 * 1024, 1}, + R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, + R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, +// TODO(b/69425338): This uses too much memory on GPU. +#ifndef XLA_TEST_BACKEND_GPU + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +#endif + R1Spec{10, 2, 4, 2}, + R1Spec{10, 0, 10, 2}, + R1Spec{10, 0, 10, 3}, + R1Spec{10, 0, 10, 4}, + R1Spec{10, 0, 10, 5}, + R1Spec{10, 0, 10, 10}, + R1Spec{500, 200, 400, 7}, + R1Spec{4096, 1, 4095, 3}, + R1Spec{2047, 1024 - 24, 1024 + 160, 31}, + R1Spec{2047, 1, 2046, 3 * 128}, + R1Spec{4096, 1024 + 3, 4095, 500}, + R1Spec{8192, 0, 8192, 1024 * 3 + 400} + ), + SliceR1TestDataToString ); +// clang-format on struct R2Spec { int64 input_dim0; diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index 173fb1b0008c9e6edaa1902a5eb3ca5f054a2a67..978a669bcab720bddec5c4bcd0144810ba3c8477 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -21,12 +21,13 @@ limitations under the License. #include #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/regexp.h" namespace xla { namespace { // Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is -// disabled. +// disabled - a sequence of regexps. using ManifestT = std::unordered_map>; ManifestT ReadManifest() { @@ -66,9 +67,6 @@ ManifestT ReadManifest() { string PrependDisabledIfIndicated(const string& test_case_name, const string& test_name) { - // TODO(leary): this code reads the manifest for every test case instantiated - // in every file. Consider switching to a singleton or using a compile-time - // genrule instead. ManifestT manifest = ReadManifest(); // First try full match: test_case_name.test_name @@ -83,11 +81,13 @@ string PrependDisabledIfIndicated(const string& test_case_name, } } + // Expect a full match vs. one of the platform regexps to disable the test. const std::vector& disabled_platforms = it->second; string platform_string = XLA_PLATFORM; - if (std::find(disabled_platforms.begin(), disabled_platforms.end(), - platform_string) != disabled_platforms.end()) { - return "DISABLED_" + test_name; + for (const auto& s : disabled_platforms) { + if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { + return "DISABLED_" + test_name; + } } // We didn't hit in the disabled manifest entries, so don't disable it. diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 3878ac1013ef1459cbe3c92a48fc6149b6a4948e..28a2d0198a707cec1aa5e0fbed341ee9b2a927f7 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -66,8 +66,10 @@ limitations under the License. namespace xla { -// Reads a disabled manifest file (and retains it as a singleton) to resolve -// whether test cases should be disabled on a particular platform. +// Reads a disabled manifest file to resolve whether test cases should be +// disabled on a particular platform. For a test that should be disabled, +// returns DISABLED_ prepended to its name; otherwise returns the test name +// unmodified. string PrependDisabledIfIndicated(const string& test_case_name, const string& test_name); @@ -96,7 +98,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, test_name)::test_info_ = \ ::testing::internal::MakeAndRegisterTestInfo( \ #test_case_name, \ - PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ nullptr, nullptr, \ ::testing::internal::CodeLocation(__FILE__, __LINE__), (parent_id), \ parent_class::SetUpTestCase, parent_class::TearDownTestCase, \ @@ -135,7 +138,8 @@ string PrependDisabledIfIndicated(const string& test_case_name, ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ ->AddTestPattern( \ #test_case_name, \ - PrependDisabledIfIndicated(#test_case_name, #test_name).c_str(), \ + ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ + .c_str(), \ new ::testing::internal::TestMetaFactory()); \ return 0; \ diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..780b292d1a9b819f0f37e959cdec019f03b4a595 --- /dev/null +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -0,0 +1,259 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" + +namespace xla { + +namespace { + +template +void PopulateWithRandomFloatingPointData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::minstd_rand0 engine; + std::uniform_real_distribution generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); +} + +// The standard library does not have a case for bfloat16, unsurprisingly, so we +// handle that one specially. +template <> +void PopulateWithRandomFloatingPointData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), BF16); + std::minstd_rand0 engine; + std::uniform_real_distribution generator(0.0f, 1.0f); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return static_cast(generator(engine)); + })); +} + +template +void PopulateWithRandomIntegralData(Literal* literal) { + CHECK_EQ(literal->shape().element_type(), + primitive_util::NativeToPrimitiveType()); + std::minstd_rand0 engine; + std::uniform_int_distribution generator( + std::numeric_limits::lowest(), std::numeric_limits::max()); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); +} + +// Matches binary addition computations. +bool LooksLikeSum(const HloComputation& computation) { + const HloInstruction* const root = computation.root_instruction(); + return root->opcode() == HloOpcode::kAdd && + computation.num_parameters() == 2 && + root->operand(0)->opcode() == HloOpcode::kParameter && + root->operand(1)->opcode() == HloOpcode::kParameter && + root->operand(0) != root->operand(1); +} + +// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, +// which requires an init_value of 0 rather than a random value. +bool NeedsZeroInitValue(const HloUse& use) { + const HloInstruction* const instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + return ( + ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && + op_num == 1 && LooksLikeSum(*instruction->to_apply())) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && + LooksLikeSum(*instruction->scatter()))); +} + +// Generate random values that are constrained to the input_shape minus the +// output_shape so as not to produce wrapping slices, for instance. +std::unique_ptr MakeRandomNonwrappingSliceIndex( + const Shape& input_shape, const Shape& slice_shape) { + const int64 rank = ShapeUtil::Rank(input_shape); + std::vector start_indices(rank); + std::minstd_rand0 engine; + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(engine); + } + return Literal::CreateR1(start_indices); +} + +// Use dataflow analysis on each parameter to see if there are uses that would +// be problematic when generating input data. Returns the list of instructions +// that correspond to their uses. +// +// Should be paired with the CreateLiteralForConstrainedUses() function below. +std::vector FindConstrainedUses( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + std::vector constrained_uses; + for (const auto& pair : dataflow.GetInstructionValueSet(¶m)) { + const HloValue& value = dataflow.GetUniqueValueAt(¶m, pair.first); + for (const HloUse& use : value.uses()) { + HloInstruction* instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + constrained_uses.push_back(instruction); + } else if (opcode == HloOpcode::kFusion) { + const HloInstruction* const to_analyze = + instruction->fused_parameter(op_num); + auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); + constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), + fused_uses.end()); + } else if (NeedsZeroInitValue(use)) { + constrained_uses.push_back(instruction); + } + } + } + return constrained_uses; +} + +// Given a parameter, generate a random Literal to use as input if there exist +// no constrained uses in the dataflow graph. If such constraints exist, +// generate a constrained literal (either bounded in the case of indices, or +// zero in the case of init_values for reductions). +StatusOr> CreateLiteralForConstrainedUses( + const tensorflow::gtl::ArraySlice constrained_uses, + const HloInstruction& param) { + const auto count = constrained_uses.size(); + if (count > 1) { + return Unimplemented("multiple constrained uses not yet supported"); + } + + if (count == 0) { + return MakeFakeLiteral(param.shape()); + } + + const HloInstruction* const use = constrained_uses[0]; + switch (use->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + return MakeRandomNonwrappingSliceIndex(use->operand(0)->shape(), + use->shape()); + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + return Literal::CreateFromShape(param.shape()); + default: + return Unimplemented("constrained use given; no equivalent literal"); + } +} + +// Given a module entry parameter, use the dataflow analysis to see if a +// special case literal must be created, or if we can generate fake data. +StatusOr> MakeConstrainedArgument( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + const auto constrained_uses = FindConstrainedUses(dataflow, param); + return CreateLiteralForConstrainedUses(constrained_uses, param); +} + +} // namespace + +StatusOr> MakeFakeLiteral(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + std::vector> elements; + for (const Shape& element_shape : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr element, + MakeFakeLiteral(element_shape)); + elements.push_back(std::move(element)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } + std::unique_ptr literal = Literal::CreateFromShape(shape); + switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData(literal.get()); + break; + case F32: + PopulateWithRandomFloatingPointData(literal.get()); + break; + case F64: + PopulateWithRandomFloatingPointData(literal.get()); + break; + case S8: + PopulateWithRandomIntegralData(literal.get()); + break; + case U8: + PopulateWithRandomIntegralData(literal.get()); + break; + case S16: + PopulateWithRandomIntegralData(literal.get()); + break; + case U16: + PopulateWithRandomIntegralData(literal.get()); + break; + case S32: + PopulateWithRandomIntegralData(literal.get()); + break; + case U32: + PopulateWithRandomIntegralData(literal.get()); + break; + case S64: + PopulateWithRandomIntegralData(literal.get()); + break; + case U64: + PopulateWithRandomIntegralData(literal.get()); + break; + case PRED: { + std::uniform_int_distribution generator(0, 1); + std::minstd_rand0 engine; + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); + break; + } + default: + return Unimplemented("Unsupported type for fake literal generation: %s", + ShapeUtil::HumanString(shape).c_str()); + } + return std::move(literal); +} + +StatusOr>> MakeFakeArguments( + HloModule* const module) { + TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); + const auto params = module->entry_computation()->parameter_instructions(); + std::vector> arguments(params.size()); + for (int i = 0; i < params.size(); ++i) { + TF_ASSIGN_OR_RETURN(arguments[i], + MakeConstrainedArgument(*dataflow, *params[i])); + } + return std::move(arguments); +} + +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module) { + return HloVerifier( + std::bind( + &TransferManager::GetByteSizeRequirement, + TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(), + std::placeholders::_1)) + .Run(module) + .status(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index f3a522b05ebae4f1f86d6d7ddbac6e1749d3e286..0fb024ffb074f1c90b75022bc7f5a8b58b03c0c2 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -23,12 +23,13 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/stream_executor/platform.h" namespace xla { -namespace test_utils { // A class which generates pseudorandom numbers of a given type within a given // range. Not cryptographically secure and likely not perfectly evenly @@ -53,63 +54,23 @@ class PseudorandomGenerator { std::mt19937 generator_; }; -// Convenience function for creating a rank-2 array with arbitrary layout. -template -std::unique_ptr CreateR2LiteralWithLayout( - std::initializer_list> values, - tensorflow::gtl::ArraySlice minor_to_major) { - auto literal = MakeUnique(); - const int64 d0 = values.size(); - const int64 d1 = values.begin()->size(); - literal.get()->PopulateWithValue(0, {d0, d1}); - *literal->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout(minor_to_major); - TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); +// Generates fake data in a literal of the given shape, or returns an error +// status if the element type is currently unhandled for fake data generation. +StatusOr> MakeFakeLiteral(const Shape& shape); - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto value : inner_list) { - literal.get()->Set({dim0, dim1}, value); - ++dim1; - } - ++dim0; - } - return literal; -} +// Generates a vector of arguments containing fake data. The number, shape and +// layout of the arguments is appropriate for given HLO module. +// +// Will handle special cases such as making sure that indices used for dynamic +// slices are bounded, reduces that call adds use 0 as an init value, etc. +StatusOr>> MakeFakeArguments( + HloModule* const module); -// Convenience function for creating a rank-3 array with arbitrary layout. -template -std::unique_ptr CreateR3LiteralWithLayout( - std::initializer_list>> - values, - tensorflow::gtl::ArraySlice minor_to_major) { - auto literal = MakeUnique(); - const int64 d0 = values.size(); - const int64 d1 = values.begin()->size(); - const int64 d2 = values.begin()->begin()->size(); - literal.get()->PopulateWithValue(0, {d0, d1, d2}); - *literal->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout(minor_to_major); - TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); - - int64 dim0 = 0; - for (auto inner_list : values) { - int64 dim1 = 0; - for (auto inner_inner_list : inner_list) { - int64 dim2 = 0; - for (auto value : inner_inner_list) { - literal.get()->Set({dim0, dim1, dim2}, value); - ++dim2; - } - ++dim1; - } - ++dim0; - } - return literal; -} +// Check that a given module satisfies various constraints before trying to +// execute it. +Status VerifyHloModule(const perftools::gputools::Platform& platform, + HloModule* const module); -} // namespace test_utils } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2a64749482e5f5a8c5d72034fb7a4eee07baf48 --- /dev/null +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -0,0 +1,215 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/local_client_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TransferManagerTest : public LocalClientTestBase { + protected: + TransferManagerTest() + : shape_size_fn_([this](const Shape& shape) { + return transfer_manager_->GetByteSizeRequirement(shape); + }) {} + + ~TransferManagerTest() override = default; + + std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { + return ScopedShapedBuffer::Allocate( + shape, GetOrCreateAllocator(local_client_->platform()), + /*device_ordinal=*/0, shape_size_fn_) + .ValueOrDie(); + } + + private: + std::function shape_size_fn_; +}; + +XLA_TEST_F(TransferManagerTest, TransferR0U32) { + std::unique_ptr literal = Literal::CreateR0(42); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR0Equal(42, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1F32) { + std::unique_ptr literal = + Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR1Equal({1.25f, 2.5f, -17.0f, -20.125f}, + *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1LargeF32) { + std::vector test_vector(1024 * 1024); + std::iota(test_vector.begin(), test_vector.end(), 0); + std::unique_ptr literal = Literal::CreateR1(test_vector); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR1Equal(test_vector, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferR1U8) { + const char* test_string = "0123456789abcdef"; + std::unique_ptr literal = Literal::CreateR1U8(test_string); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + EXPECT_EQ(result->u8s_string(), test_string); +} + +XLA_TEST_F(TransferManagerTest, TransferR2F32) { + std::unique_ptr literal = + Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + const Shape& shape = literal->shape(); + auto device_buffer = AllocateDeviceBuffer(shape); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); +} + +XLA_TEST_F(TransferManagerTest, + TransferR2F32AndChangeLayoutTransferringToDevice) { + std::unique_ptr literal = Literal::CreateR2WithLayout( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, LayoutUtil::MakeLayout({0, 1})); + const Shape ondevice_shape = + ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + auto device_buffer = AllocateDeviceBuffer(ondevice_shape); + + // Round trip literal through device. Set the on-device layout to something + // different than the literal layout. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + EXPECT_FALSE( + LayoutUtil::Equal(result->shape().layout(), literal->shape().layout())); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) { + std::unique_ptr literal = Literal::MakeTuple({}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR0(123.0f).get(), + Literal::MakeTuple( + {Literal::CreateR2({{1.0f, 2.0f}, {4.0f, 5.0f}}).get(), + Literal::CreateR1({44.0f, -10.0f, 3333333.3f}).get()}) + .get(), + Literal::CreateR1({-10.0f, 123.0f}).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 4920f17a7ed21d587c15b8deac550d5e5bb566c9..65489cfff19c8fecbdead8a7e295bf9cca56038f 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -180,7 +180,8 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } -XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { +// TODO(b/68395210): GPU does not tolerate ambiguous top-level buffers. +XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) { ComputationBuilder b(client_, TestName()); ComputationDataHandle v1, v2; @@ -444,5 +445,61 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { ComputeAndCompareR1(&builder, expected, arguments, ErrorSpec(1e-5)); } +XLA_TEST_F(TupleTest, ComplexTuples) { + ComputationBuilder builder(client_, TestName()); + { + Shape c64r0 = ShapeUtil::MakeShape(C64, {}); + Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); + Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); + Shape arg0_shape = ShapeUtil::MakeTupleShape( + {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); + auto input0 = builder.Parameter(0, arg0_shape, "input0"); + auto t0 = builder.GetTupleElement(input0, 0); + auto t1 = builder.GetTupleElement(input0, 1); + auto t10 = builder.GetTupleElement(t1, 0); + auto t11 = builder.GetTupleElement(t1, 1); + auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); + auto input1 = builder.Parameter(1, c64r1, "input1"); + auto prod = builder.Mul(input1, sum, {1}); + builder.Tuple({builder.Tuple({prod, sum}), + builder.ConstantR0({123, 456})}); + } + + std::unique_ptr arg0 = + client_ + ->TransferToServer(*Literal::MakeTuple( + {Literal::CreateR0({1, 2}).get(), + Literal::MakeTuple( + {Literal::CreateR1({{10, 20}, {30, 40}}).get(), + Literal::CreateR2( + {{{100, 200}, {300, 400}}, + {{1000, 2000}, {3000, 4000}}, + {{10000, 20000}, {30000, 40000}}}) + .get()}) + .get()})) + .ConsumeValueOrDie(); + std::unique_ptr arg1 = + client_ + ->TransferToServer(*Literal::CreateR1({{1, 2}, {1, -2}})) + .ConsumeValueOrDie(); + auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, + {{1011, 2022}, {3031, 4042}}, + {{10011, 20022}, {30031, 40042}}}); + auto prod = Literal::CreateFromShape(sum->shape()); + ASSERT_TRUE(prod->Populate( + [&sum](tensorflow::gtl::ArraySlice indexes) { + return sum->Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? complex64(1, 2) + : complex64(1, -2)); + }) + .ok()); + auto expected = + Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), + Literal::CreateR0({123, 456}).get()}); + ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 71a1b0abee51ba2819daed23208b0da8d5107207..0b3430ee1ee515c2c98c64a947b7a7021c04f22b 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -357,6 +357,109 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { + std::vector shape_elements = { + ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for N iterations. + const int N = 2; + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(N), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and permute the weights. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto w1 = builder.GetTupleElement(prev, 1); + auto w2 = builder.GetTupleElement(prev, 2); + auto w3 = builder.GetTupleElement(prev, 3); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), + builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(N); + auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); + auto expected_w2 = Literal::CreateR1({2.0f, 2.0f, 2.0f}); + auto expected_w3 = Literal::CreateR1({3.0f, 3.0f, 3.0f}); + auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + +TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { + std::vector shape_elements = { + ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for N iterations. + const int N = 2; + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(N), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable permute the weights. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto w1 = builder.GetTupleElement(prev, 1); + auto w2 = builder.GetTupleElement(prev, 2); + auto w3 = builder.GetTupleElement(prev, 3); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), + builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); + auto xla_while = builder.While(condition, body, init); + + auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1), + builder.GetTupleElement(xla_while, 2)); + auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + std::vector expected = {6.f, 6.f, 6.f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // Tests a while node when the result type T is a Tuple. // // tuple> result(0, vector(10, 0.0f)); @@ -808,8 +911,7 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { } } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); @@ -845,8 +947,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { ErrorSpec(1e-6)); } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithBroadcast)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); @@ -899,6 +1000,51 @@ TEST_F(WhileTest, WhileThatTurnsScalarParameterToTupleElement) { ErrorSpec(1e-6)); } +// Tests loop where the init value comes from two sources (constant and +// parameter). +// +// int32 result = (0, 1); +// while (result[0] + result[1] < 30) { +// result[0] = result[0] + 1; +// result[1] = result[1] + 1; +// } +TEST_F(WhileTest, WhileWithMixedTupleElements) { + auto result_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})}); + + ComputationBuilder outer(client_, "outer"); + auto p = + outer.Tuple({outer.ConstantR0(0), + outer.Parameter(0, ShapeUtil::MakeShape(S32, {}), "t")}); + + ComputationBuilder cond(client_, "cond"); + auto params = cond.Parameter(0, result_shape, "prev"); + auto cond_t = cond.Add(cond.GetTupleElement(params, 1), + cond.GetTupleElement(params, 0)); + cond.Lt(cond_t, cond.ConstantR0(30)); + + ComputationBuilder body(client_, "body"); + auto body_t = body.Parameter(0, result_shape, "t"); + + auto tuple = body.Tuple( + {body.Add(body.GetTupleElement(params, 0), body.ConstantR0(1)), + body.Add(body.GetTupleElement(params, 1), body.ConstantR0(1))}); + + TF_ASSERT_OK_AND_ASSIGN(auto cond_computation, cond.Build()); + TF_ASSERT_OK_AND_ASSIGN(auto body_computation, body.Build()); + outer.While(cond_computation, body_computation, p); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr parameter_data, + client_->TransferToServer(*Literal::CreateR0(1))); + + auto add1 = Literal::CreateR0(15); + auto add2 = Literal::CreateR0(16); + auto expected = Literal::MakeTuple({add1.get(), add2.get()}); + ComputeAndCompareTuple(&outer, *expected, {parameter_data.get()}, + ErrorSpec(1e-6)); +} + // Tests nested while loops. // // int32 result = 0; diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index 759921dce5acf3cd23a121776f3ab0731c9bb623..091fa0c3ec807a66449eca0bfbb141285b8eb532 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -88,6 +88,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:testing", "//tensorflow/compiler/xla/service:session_proto", + "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index c84ca9fc833881ce49bcaad5dd85394145151912..97aacf6b39f83978e732060817cd93ede81ca782 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -34,9 +34,9 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", ], @@ -48,6 +48,7 @@ cc_library( hdrs = ["hlo_parser.h"], deps = [ ":hlo_lexer", + "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -64,6 +65,7 @@ tf_cc_test( srcs = ["hlo_parser_test.cc"], deps = [ ":hlo_parser", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 2feaa49db86ea700cab0b794ec441b95ac03b468..6232967f5f04cbf316d985357ae84c28335531e2 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -43,14 +43,22 @@ operand : shape name ; -extra_attributes +attributes : /*empty*/ - | ',' extra_attribute - | ',' extra_attribute extra_attributes + | ',' attribute + | ',' attribute attributes ; -extra_attribute +attribute : attribute_name attribute_value ; +attribute_value + : kInt + | kName + | [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} /*dim_labels_pattern*/ + | [0-9]+(x[0-9]+)+ /*dxd_pattern*/ + | [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* /*pad_pattern*/ + | '{' sub_attributes '}' + ; param_list : '(' param_list1 ')' @@ -82,4 +90,25 @@ identifier : [a-zA-Z_][a-zA-Z0-9_.-]* ; +/* literal is in the right hand side of a constant instruction. */ +literal + : tuple + | non_tuple + ; +tuple + : shape '(' literal_list ')' + ; +literal_list + : /*empty*/ + : literal + | literal_list ',' literal + ; +non_tuple + : rank01 + | rank2345 + ; +rank2345 + : shape nested_array + ; + ``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index 486df6854016d2d796781d722e6a6a27273e1cf3..459d511e90d87537f3a3404b82df7b28b1fe08bd 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { @@ -122,7 +123,7 @@ TokKind HloLexer::LexToken() { current_ptr_++; return TokKind::kArrow; } - return LexDigitOrNegative(); + return LexNumberOrPattern(); case '=': return TokKind::kEqual; case ',': @@ -143,22 +144,29 @@ TokKind HloLexer::LexToken() { return TokKind::kLparen; case ')': return TokKind::kRparen; + case '/': + return LexComment(); + case '"': + return LexString(); } } } -// Lex a shape, name, keyword, or opcode. +// Lex a shape, name, keyword, attribute name, the dim labels pattern, and +// other identifiers. +// // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: // keyword ::= HloModule, ENTRY, ... -// opcode ::= add, greater-than, ... // attribute_name ::= condition, body, dimensions, ... +// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); // 'consumable' will be advanced iff its prefix matches the pattern. static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"}; + R"(^(\w*\d*)\[([\d,]*)\](?:{([\d,]*)})?)"}; if (RE2::Consume(&consumable, *shape_pattern)) { auto status_or_shape = ShapeUtil::ParseShapeString( StringPieceFromPointers(token_start_, consumable.begin())); @@ -201,6 +209,8 @@ TokKind HloLexer::LexIdentifier() { KEYWORD(true); KEYWORD(false); + KEYWORD(inf); + KEYWORD(nan); KEYWORD(HloModule); KEYWORD(ENTRY); KEYWORD(ROOT); @@ -209,15 +219,19 @@ TokKind HloLexer::LexIdentifier() { #undef KEYWORD - // See if this is an opcode. - auto opcode = StringToHloOpcode(identifier.ToString()); - if (opcode.ok()) { - opcode_val_ = opcode.ValueOrDie(); - return TokKind::kOpcode; + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 dim_labels_pattern = { + R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; + if (RE2::Consume(&consumable, *dim_labels_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDimLabels; + } } - current_ptr_ = token_start_ + 1; - return TokKind::kError; + str_val_ = identifier.ToString(); + return TokKind::kIdent; } // Lex names after a % character. @@ -236,14 +250,20 @@ TokKind HloLexer::LexPercent() { return TokKind::kError; } -// Lex integer and floating-point values. -// int [-]?[0-9]+ -// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) -// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) -TokKind HloLexer::LexDigitOrNegative() { +// Lex integer and floating-point values, -inf, and patterns for dim labels, +// dxd (e.g. 1x2x3), and pad. +// +// fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// dxd_pattern ::= [0-9]+(x[0-9]+)+ +// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* +// int ::= [-]?[0-9]+ +// negative inf ::= '-inf' +TokKind HloLexer::LexNumberOrPattern() { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 float_pattern = { - R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"}; if (RE2::Consume(&consumable, *float_pattern)) { current_ptr_ = consumable.begin(); tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), @@ -251,6 +271,30 @@ TokKind HloLexer::LexDigitOrNegative() { return TokKind::kDecimal; } + static LazyRE2 dim_labels_pattern = { + R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; + static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"}; + static LazyRE2 pad_pattern = { + R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"}; + + if (RE2::Consume(&consumable, *dim_labels_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDimLabels; + } + + if (RE2::Consume(&consumable, *dxd_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kDxD; + } + + if (RE2::Consume(&consumable, *pad_pattern)) { + current_ptr_ = consumable.begin(); + str_val_.assign(token_start_, current_ptr_); + return TokKind::kPad; + } + static LazyRE2 int_pattern = {R"([-]?\d+)"}; if (RE2::Consume(&consumable, *int_pattern)) { current_ptr_ = consumable.begin(); @@ -259,23 +303,154 @@ TokKind HloLexer::LexDigitOrNegative() { return TokKind::kInt; } + static LazyRE2 neg_inf = {"-inf"}; + if (RE2::Consume(&consumable, *neg_inf)) { + current_ptr_ = consumable.begin(); + return TokKind::kNegInf; + } + return TokKind::kError; } -StringPiece HloLexer::GetCurrentLine() const { - const char* start = token_start_; - const char* end = current_ptr_; - if (!CanDereference(start) || !CanDereference(end)) { - return "LINE OUT OF RANGE"; +std::pair HloLexer::GetLineAndColumn(LocTy location) const { + unsigned line_no = 1; + const char* start = buf_.begin(); + const char* ptr = start; + if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) && + line_no_cache_.last_query <= location) { + ptr = line_no_cache_.last_query; + line_no = line_no_cache_.line_no_of_query; + } + for (; ptr != location; ptr++) { + if (*ptr == '\n') { + line_no++; + } } - while (start > buf_.begin() && *start != '\n') { - start--; + + // Update the line number cache. + line_no_cache_.last_query = ptr; + line_no_cache_.line_no_of_query = line_no; + size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); + if (line_offset == StringPiece::npos) { + line_offset = 0; } - while (end < buf_.end() && *end != '\n') { - end++; + return {line_no, ptr - start - line_offset}; +} + +StringPiece HloLexer::GetLine(LocTy loc) const { + if (!CanDereference(loc)) { + return "LINE OUT OF RANGE"; } + size_t line_start = + StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); + const char* start = line_start == StringPiece::npos + ? buf_.begin() + : buf_.begin() + line_start + 1; + size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); + const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + return StringPieceFromPointers(start, end); } +TokKind HloLexer::LexComment() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"}; + if (RE2::Consume(&consumable, *comment_pattern)) { + current_ptr_ = consumable.begin(); + return TokKind::kComment; + } + return TokKind::kError; +} + +// Lexes quoted string with escaping characters. If matched, the quoted string +// will be unescaped and stored to str_val_. +TokKind HloLexer::LexString() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; + if (RE2::Consume(&consumable, *escaping_pattern)) { + current_ptr_ = consumable.begin(); + StringPiece raw = + StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + string error; + if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error; + return TokKind::kError; + } + return TokKind::kString; + } + return TokKind::kError; +} + +string TokKindToString(TokKind kind) { + switch (kind) { + case TokKind::kEof: + return "kEof"; + case TokKind::kError: + return "kError"; + case TokKind::kEqual: + return "kEqaul"; + case TokKind::kComma: + return "kComma"; + case TokKind::kColon: + return "kColon"; + case TokKind::kLsquare: + return "kLsquare"; + case TokKind::kRsquare: + return "kRsquare"; + case TokKind::kLbrace: + return "kLbrace"; + case TokKind::kRbrace: + return "kRbrace"; + case TokKind::kLparen: + return "kLparen"; + case TokKind::kRparen: + return "kRparen"; + case TokKind::kArrow: + return "kArrow"; + case TokKind::kComment: + return "kComment"; + case TokKind::kw_HloModule: + return "kw_HloModule"; + case TokKind::kw_ENTRY: + return "kw_ENTRY"; + case TokKind::kw_ROOT: + return "kw_ROOT"; + case TokKind::kw_true: + return "kw_true"; + case TokKind::kw_false: + return "kw_false"; + case TokKind::kw_maximal: + return "kw_maximal"; + case TokKind::kw_replicated: + return "kw_replicated"; + case TokKind::kw_nan: + return "kw_nan"; + case TokKind::kw_inf: + return "kw_inf"; + case TokKind::kNegInf: + return "kNegInf"; + case TokKind::kName: + return "kName"; + case TokKind::kAttributeName: + return "kAttributeName"; + case TokKind::kDimLabels: + return "kDimLabels"; + case TokKind::kDxD: + return "kDxD"; + case TokKind::kPad: + return "kPad"; + case TokKind::kIdent: + return "kIdent"; + case TokKind::kString: + return "kString"; + case TokKind::kShape: + return "kShape"; + case TokKind::kInt: + return "kInt"; + case TokKind::kDecimal: + return "kDecimal"; + } +} + } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 433a3a3601e969de154d2f463f650f5f0b07a49f..27880b9b8afbfa58abfedc3b2cecd5236b78a6d6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" @@ -37,11 +37,17 @@ class HloLexer { } TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } string GetStrVal() const { switch (GetKind()) { case TokKind::kName: case TokKind::kAttributeName: + case TokKind::kDimLabels: + case TokKind::kDxD: + case TokKind::kPad: + case TokKind::kString: + case TokKind::kIdent: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -51,10 +57,6 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - HloOpcode GetOpcodeVal() const { - CHECK(GetKind() == TokKind::kOpcode); - return opcode_val_; - } int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; @@ -64,8 +66,16 @@ class HloLexer { return decimal_val_; } - // Returns the line of text that is currently being lexed. - tensorflow::StringPiece GetCurrentLine() const; + typedef const char* LocTy; + + // Returns the location of the current token. + LocTy GetLoc() const { return token_start_; } + + // Returns the line and column of a location in the buffer. + std::pair GetLineAndColumn(LocTy location) const; + + // Returns the whole line given the location. + tensorflow::StringPiece GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -92,7 +102,9 @@ class HloLexer { TokKind LexPercent(); TokKind LexShape(); TokKind LexConstant(); - TokKind LexDigitOrNegative(); + TokKind LexNumberOrPattern(); + TokKind LexComment(); + TokKind LexString(); const tensorflow::StringPiece buf_; const char* current_ptr_; @@ -102,9 +114,15 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - HloOpcode opcode_val_; int64 int64_val_; double decimal_val_; + + struct LineNoCacheTy { + const char* last_query; + unsigned line_no_of_query; + }; + // This caches the line number of the previous query. + mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; } // namespace tools diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 5dd8ec6636ecca6f34fff39f285454ee0764a8ad..457b6557836bb2767ce9d05c4494855a0944ca60 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -15,9 +15,13 @@ limitations under the License. #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { namespace tools { @@ -25,12 +29,22 @@ namespace tools { namespace { using tensorflow::StringPiece; +using tensorflow::gtl::optional; +using tensorflow::str_util::Split; +using tensorflow::str_util::SplitAndParseAsInts; +using tensorflow::strings::Printf; +using tensorflow::strings::StrAppend; using tensorflow::strings::StrCat; +const double kF16max = 65504; + // Parser for the HloModule::ToString() format text. class HloParser { public: - explicit HloParser(StringPiece str) : lexer_(str) {} + using LocTy = HloLexer::LocTy; + + explicit HloParser(StringPiece str, const HloModuleConfig& config) + : lexer_(str), config_(config) {} // Runs the parser. Returns false if an error occurred. bool Run(); @@ -49,42 +63,146 @@ class HloParser { bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); - bool ParseSharding(HloInstruction* instruction); + bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseNonTupleLiteral(std::unique_ptr* literal, + const Shape& shape); + // Sets the sub-value of literal at the given index to the given value. The + // literal's shape must have the default layout. + bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(double value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal); + template + bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + Literal* literal); + bool ParseOperands(std::vector* operands); - // Fill parsed operands into 'operands' and expect a certain number of + // Fills parsed operands into 'operands' and expects a certain number of // operands. bool ParseOperands(std::vector* operands, const int expected_size); - template - bool ParseExtraAttribute(T* value, const string& expected_attribute); - template - bool ParseAttributeValue(T* value); + // Describes the start, limit, and stride on every dimension of the operand + // being sliced. + struct SliceRanges { + std::vector starts; + std::vector limits; + std::vector strides; + }; + + // Types of attributes. + enum class AttrTy { + kInt64, + kInt32, + kFloat, + kString, + kBracedInt64List, + kHloComputation, + kWindow, + kConvolutionDimensionNumbers, + kSharding, + kInstructionList, + kSliceRanges, + kPaddingConfig, + kMetadata, + kFusionKind, + kDistribution, + }; + + struct AttrConfig { + bool required; // whether it's required or optional + AttrTy attr_type; // what type it is + void* result; // where to store the parsed result. + }; + + // attributes ::= (',' attribute)* + // + // Parses attributes given names and configs of the attributes. Each parsed + // result is passed back through the result pointer in corresponding + // AttrConfig. Note that the result pointer must point to a optional typed + // variable which outlives this function. Returns false on error. You should + // not use the any of the results if this function failed. + // + // Example usage: + // + // std::unordered_map attrs; + // optional foo; + // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo}; + // optional bar; + // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar}; + // if (!ParseAttributes(attrs)) { + // return false; // Do not use 'foo' 'bar' if failed. + // } + // // Do something with 'bar'. + // if (foo) { // If attr foo is seen, do something with 'foo'. } + // + bool ParseAttributes(const std::unordered_map& attrs); + + // sub_attributes ::= '{' (','? attribute)* '}' + // + // Usage is the same as ParseAttributes. See immediately above. + bool ParseSubAttributes(const std::unordered_map& attrs); + + // Parses one attribute. If it has already been seen, return error. Returns + // true and adds to seen_attrs on success. + // + // Do not call this except in ParseAttributes or ParseSubAttributes. + bool ParseAttributeHelper(const std::unordered_map& attrs, + std::unordered_set* seen_attrs); + + // Parses a name and finds the corresponding hlo computation. + bool ParseComputationName(HloComputation** value); + // Parses a list of names and finds the corresponding hlo instructions. + bool ParseInstructionNames(std::vector* instructions); + bool ParseWindow(Window* window); + bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); + bool ParsePaddingConfig(PaddingConfig* padding); + bool ParseMetadata(OpMetadata* metadata); + bool ParseSharding(OpSharding* sharding); + bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + + // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. + bool ParseDxD(const string& name, std::vector* result); + // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. + bool ParseWindowPad(std::vector>* pad); + + bool ParseSliceRanges(SliceRanges* result); + bool ParseInt64List(const TokKind start, const TokKind end, + const TokKind delim, std::vector* result); bool ParseParamList(); bool ParseName(string* result); bool ParseAttributeName(string* result); + bool ParseString(string* result); bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); + bool ParseFusionKind(HloInstruction::FusionKind* result); + bool ParseRandomDistribution(RandomDistribution* result); bool ParseInt64(int64* result); - bool ParseDecimal(double* result); + bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); // Logs the current parsing line and the given message. Always returns false. bool TokenError(StringPiece msg); + bool Error(LocTy loc, StringPiece msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. bool EatIfPresent(TokKind kind); + // Parses a shape, and returns true if the result is compatible with the given + // shape. + bool EatShapeAndCheckCompatible(const Shape& shape); // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. - bool AddInstruction(const string& name, HloInstruction* instruction); + bool AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc); // Adds the computation to the pool. Returns false and emits an error if the // computation already exists. - bool AddComputation(const string& name, HloComputation* computation); + bool AddComputation(const string& name, HloComputation* computation, + LocTy name_loc); // The map from the instruction name to the instruction. This does not own the // instructions. @@ -93,15 +211,29 @@ class HloParser { HloLexer lexer_; std::unique_ptr module_; + const HloModuleConfig config_; std::vector error_; }; -bool HloParser::TokenError(StringPiece msg) { - error_.push_back( - StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg)); +bool HloParser::Error(LocTy loc, StringPiece msg) { + auto line_col = lexer_.GetLineAndColumn(loc); + const unsigned line = line_col.first; + const unsigned col = line_col.second; + std::vector error_lines; + error_lines.push_back( + StrCat("was parsing ", line, ":", col, ": error: ", msg)); + error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^")); + + error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + VLOG(1) << "Error: " << error_.back(); return false; } +bool HloParser::TokenError(StringPiece msg) { + return Error(lexer_.GetLoc(), msg); +} + bool HloParser::Run() { lexer_.Lex(); return ParseHloModule(); @@ -120,7 +252,7 @@ bool HloParser::ParseHloModule() { return false; } - module_ = MakeUnique(name); + module_ = MakeUnique(name, config_); return ParseComputations(); } @@ -139,6 +271,7 @@ bool HloParser::ParseComputations() { bool HloParser::ParseComputation() { const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); string name; + LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name)) { return false; } @@ -159,6 +292,7 @@ bool HloParser::ParseComputation() { LOG(FATAL) << "instruction " << root_name << " was marked as ROOT but the parser has not seen it before"; } + // Now root can be either an existing instruction or a nullptr. If it's a // nullptr, the implementation of Builder will set the last instruction as // root instruction. @@ -166,7 +300,7 @@ bool HloParser::ParseComputation() { is_entry_computation ? module_->AddEntryComputation(builder->Build(root)) : module_->AddEmbeddedComputation(builder->Build(root)); - return AddComputation(name, computation); + return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' @@ -186,7 +320,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder, "expects '}' at the end of instruction list."); } -// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)* +// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; @@ -194,6 +328,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloOpcode opcode; std::vector operands; bool is_root = EatIfPresent(TokKind::kw_ROOT); + + const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || !ParseToken(TokKind::kEqual, "expects '=' in instruction") || !ParseShape(&shape) || !ParseOpcode(&opcode)) { @@ -202,6 +338,17 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (is_root) { *root_name = name; } + + // Add optional attributes. + std::unordered_map attrs; + optional sharding; + attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; + optional> predecessors; + attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, + &predecessors}; + optional metadata; + attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -209,7 +356,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number) || - !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) { + !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -221,7 +369,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || - !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) { + !ParseToken(TokKind::kRparen, "expects ')' after constant literal") || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -247,7 +396,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -277,7 +427,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: { - if (!ParseOperands(&operands, /*expected_size=*/2)) { + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateBinary( @@ -287,7 +438,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: { - if (!ParseOperands(&operands, /*expected_size=*/3)) { + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateTernary( @@ -296,23 +448,34 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } // Other supported ops. case HloOpcode::kConvert: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( HloInstruction::CreateConvert(shape, operands[0])); break; } + case HloOpcode::kBitcastConvert: { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateBitcastConvert(shape, operands[0])); + break; + } case HloOpcode::kCrossReplicaSum: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands[0])); + HloInstruction::CreateCrossReplicaSum(shape, operands)); break; } case HloOpcode::kReshape: { - if (!ParseOperands(&operands, /*expected_size=*/1)) { + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( @@ -320,7 +483,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTuple: { - if (!ParseOperands(&operands)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = @@ -328,114 +491,452 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kWhile: { - HloComputation* condition; - HloComputation* body; + optional condition; + optional body; + attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation, + &condition}; + attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&condition, - /*expected_attribute=*/"condition") || - !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateWhile( - shape, condition, body, /*init=*/operands[0])); + shape, *condition, *body, /*init=*/operands[0])); break; } case HloOpcode::kRecv: { - int64 channel_id; + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseExtraAttribute(&channel_id, - /*expected_attribute=*/"channel_id")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape, channel_id)); + HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id)); + break; + } + case HloOpcode::kRecvDone: { + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + if (channel_id != operands[0]->channel_id()) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0])); break; } case HloOpcode::kSend: { - int64 channel_id; + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&channel_id, - /*expected_attribute=*/"channel_id")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], channel_id)); + HloInstruction::CreateSend(operands[0], *channel_id)); + break; + } + case HloOpcode::kSendDone: { + optional channel_id; + attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + if (channel_id != operands[0]->channel_id()) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateSendDone(operands[0])); break; } case HloOpcode::kGetTupleElement: { - int64 index; + optional index; + attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) { + !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateGetTupleElement(shape, operands[0], index)); + HloInstruction::CreateGetTupleElement(shape, operands[0], *index)); break; } case HloOpcode::kCall: { - HloComputation* to_apply; - if (!ParseOperands(&operands) || - !ParseExtraAttribute(&to_apply, - /*expected_attribute=*/"to_apply")) { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCall(shape, operands, *to_apply)); + break; + } + case HloOpcode::kReduceWindow: { + optional reduce_computation; + optional window; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &reduce_computation}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + if (!window) { + window.emplace(); + } + instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow( + shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window, + *reduce_computation)); + break; + } + case HloOpcode::kConvolution: { + optional window; + optional dnums; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + attrs["dim_labels"] = {/*required=*/true, + AttrTy::kConvolutionDimensionNumbers, &dnums}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + if (!window) { + window.emplace(); + } + instruction = builder->AddInstruction(HloInstruction::CreateConvolve( + shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); + break; + } + case HloOpcode::kBroadcast: { + optional> broadcast_dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &broadcast_dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateBroadcast( + shape, operands[0], *broadcast_dimensions)); + break; + } + case HloOpcode::kConcatenate: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs) || + dimensions->size() != 1) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConcatenate( + shape, operands, dimensions->at(0))); + break; + } + case HloOpcode::kMap: { + optional to_apply; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &to_apply}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCall(shape, operands, to_apply)); + HloInstruction::CreateMap(shape, operands, *to_apply)); break; } - case HloOpcode::kBroadcast: + case HloOpcode::kReduce: { + optional reduce_computation; + attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, + &reduce_computation}; + optional> dimensions_to_reduce; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions_to_reduce}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateReduce( + shape, /*operand=*/operands[0], /*init_value=*/operands[1], + *dimensions_to_reduce, *reduce_computation)); + break; + } + case HloOpcode::kReverse: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateReverse(shape, operands[0], *dimensions)); + break; + } + case HloOpcode::kSelectAndScatter: { + optional select; + attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select}; + optional scatter; + attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter}; + optional window; + attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + if (!window) { + window.emplace(); + } + instruction = + builder->AddInstruction(HloInstruction::CreateSelectAndScatter( + shape, /*operand=*/operands[0], *select, *window, + /*source=*/operands[1], /*init_value=*/operands[2], *scatter)); + break; + } + case HloOpcode::kSlice: { + optional slice_ranges; + attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateSlice( + shape, operands[0], slice_ranges->starts, slice_ranges->limits, + slice_ranges->strides)); + break; + } + case HloOpcode::kDynamicSlice: { + optional> dynamic_slice_sizes; + attrs["dynamic_slice_sizes"] = { + /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( + shape, /*operand=*/operands[0], /*start_indices=*/operands[1], + *dynamic_slice_sizes)); + break; + } + case HloOpcode::kDynamicUpdateSlice: { + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + shape, /*operand=*/operands[0], /*update=*/operands[1], + /*start_indices=*/operands[2])); + break; + } + case HloOpcode::kTranspose: { + optional> dimensions; + attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, + &dimensions}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateTranspose(shape, operands[0], *dimensions)); + break; + } + case HloOpcode::kBatchNormTraining: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateBatchNormTraining( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*offset=*/operands[2], *epsilon, *feature_index)); + break; + } + case HloOpcode::kBatchNormInference: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/5) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateBatchNormInference( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*offset=*/operands[2], /*mean=*/operands[3], + /*variance=*/operands[4], *epsilon, *feature_index)); + break; + } + case HloOpcode::kBatchNormGrad: { + optional epsilon; + attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; + optional feature_index; + attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, + &feature_index}; + if (!ParseOperands(&operands, /*expected_size=*/5) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad( + shape, /*operand=*/operands[0], /*scale=*/operands[1], + /*mean=*/operands[2], /*variance=*/operands[3], + /*grad_output=*/operands[4], *epsilon, *feature_index)); + break; + } + case HloOpcode::kPad: { + optional padding; + attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding}; + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreatePad( + shape, operands[0], /*padding_value=*/operands[1], *padding)); + break; + } + case HloOpcode::kFusion: { + optional fusion_computation; + attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation, + &fusion_computation}; + optional fusion_kind; + attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateFusion( + shape, *fusion_kind, operands, *fusion_computation)); + break; + } + case HloOpcode::kInfeed: { + optional config; + attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateInfeed(shape, config ? *config : "")); + break; + } + case HloOpcode::kOutfeed: { + optional config; + attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( + shape, operands[0], config ? *config : "")); + break; + } + case HloOpcode::kRng: { + optional distribution; + attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, + &distribution}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateRng(shape, *distribution, operands)); + break; + } + case HloOpcode::kReducePrecision: { + optional exponent_bits; + optional mantissa_bits; + attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, + &exponent_bits}; + attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, + &mantissa_bits}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateReducePrecision( + shape, operands[0], static_cast(*exponent_bits), + static_cast(*mantissa_bits))); + break; + } + case HloOpcode::kConditional: case HloOpcode::kCustomCall: - case HloOpcode::kConcatenate: - case HloOpcode::kReducePrecision: - case HloOpcode::kConvolution: - case HloOpcode::kMap: - case HloOpcode::kPad: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kReverse: - case HloOpcode::kRng: - case HloOpcode::kSlice: - case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: - case HloOpcode::kTranspose: - case HloOpcode::kFusion: - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormInference: - case HloOpcode::kInfeed: - case HloOpcode::kOutfeed: - case HloOpcode::kBatchNormGrad: case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - // Parse "sharding=". - if (lexer_.GetKind() == TokKind::kComma) { - if (!ParseSharding(instruction)) { - return false; + + // Add common attrs (sharding, control predecessors) to the instruction, if + // they were seen. + if (sharding) { + instruction->set_sharding( + HloSharding::FromProto(sharding.value()).ValueOrDie()); + } + if (predecessors) { + for (auto* pre : *predecessors) { + Status status = pre->AddControlDependencyTo(instruction); + if (!status.ok()) { + return Error(name_loc, StrCat("error adding control dependency for: ", + name, " status: ", status.ToString())); + } } } + if (metadata) { + instruction->set_metadata(*metadata); + } + return AddInstruction(name, instruction, name_loc); +} // NOLINT(readability/fn_size) - return AddInstruction(name, instruction); -} - -// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('[' -// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list -bool HloParser::ParseSharding(HloInstruction* instruction) { - if (!ParseToken(TokKind::kComma, - "expects ',' in front of an extra attribute")) { +// ::= '{' (single_sharding | tuple_sharding) '}' +// +// tuple_sharding ::= single_sharding* (',' single_sharding)* +bool HloParser::ParseSharding(OpSharding* sharding) { + // A single sharding starts with '{' and is not followed by '{'. + // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for + // an empty tuple. + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start sharding attribute")) { return false; } - string attribute_name; - if (!ParseAttributeName(&attribute_name) || attribute_name != "sharding") { - return TokenError("expects attribute name: sharding"); + + if (lexer_.GetKind() != TokKind::kLbrace && + lexer_.GetKind() != TokKind::kRbrace) { + return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true); } - if (!ParseToken(TokKind::kLbrace, + // Tuple sharding. + // Allow empty tuple shardings. + if (lexer_.GetKind() != TokKind::kRbrace) { + do { + if (!ParseSingleSharding(sharding->add_tuple_shardings(), + /*lbrace_pre_lexed=*/false)) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE); + + return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); +} + +// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? +// ('devices=' ('[' dims ']')* device_list)? '}' +// dims ::= int_list device_list ::= int_list +bool HloParser::ParseSingleSharding(OpSharding* sharding, + bool lbrace_pre_lexed) { + if (!lbrace_pre_lexed && + !ParseToken(TokKind::kLbrace, "expected '{' to start sharding attribute")) { return false; } + LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; std::vector devices; @@ -501,83 +1002,370 @@ bool HloParser::ParseSharding(HloInstruction* instruction) { } } - OpSharding sharding; if (replicated) { if (!devices.empty()) { - return TokenError( - "replicated shardings should not have any devices assigned"); + return Error(loc, + "replicated shardings should not have any devices assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError( - "replicated shardings should not have any tile shape set"); + return Error(loc, + "replicated shardings should not have any tile shape set"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { - return TokenError( - "maximal shardings should have exactly one device assigned"); + return Error(loc, + "maximal shardings should have exactly one device assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("maximal shardings should not have any tile shape set"); + return Error(loc, "maximal shardings should not have any tile shape set"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding.add_tile_assignment_devices(devices[0]); + sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + sharding->add_tile_assignment_devices(devices[0]); } else { if (devices.size() <= 1) { - return TokenError( - "non-maximal shardings must have more than one device assigned"); + return Error( + loc, "non-maximal shardings must have more than one device assigned"); } if (ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("non-maximal shardings should have a tile shape set"); + return Error(loc, "non-maximal shardings should have a tile shape set"); } if (tile_assignment_dimensions.empty()) { - return TokenError( + return Error( + loc, "non-maximal shardings must have a tile assignment list including " "dimensions"); } - sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding.mutable_tile_shape() = tile_shape; + sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); + *sharding->mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment_dimensions) { - sharding.add_tile_assignment_dimensions(dim); + sharding->add_tile_assignment_dimensions(dim); } for (int64 device : devices) { - sharding.add_tile_assignment_devices(device); + sharding->add_tile_assignment_devices(device); } } - instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie()); lexer_.Lex(); return true; } -bool HloParser::ParseLiteral(std::unique_ptr* literal, - const Shape& shape) { +// '{' name+ '}' +bool HloParser::ParseInstructionNames( + std::vector* instructions) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of instruction name list")) { + return false; + } + LocTy loc = lexer_.GetLoc(); + do { + string name; + if (!ParseName(&name)) { + return Error(loc, "expects a instruction name"); + } + HloInstruction* instr = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!instr) { + return TokenError( + Printf("instruction '%s' is not defined", name.c_str())); + } + instructions->push_back(instr); + } while (EatIfPresent(TokKind::kComma)); + + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction name list"); +} + +bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, + Literal* literal) { + const Shape& shape = literal->shape(); switch (shape.element_type()) { - case PRED: - bool b; - if (!ParseBool(&b)) { - return false; - } - *literal = Literal::CreateR0(b); - return true; + case S8: + return SetValueInLiteralHelper(value, linear_index, literal); + case S16: + return SetValueInLiteralHelper(value, linear_index, literal); case S32: - int64 i; - if (!ParseInt64(&i)) { - return false; - } - *literal = Literal::CreateR0(i); - return true; + return SetValueInLiteralHelper(value, linear_index, literal); + case S64: + return SetValueInLiteralHelper(value, linear_index, literal); + case U8: + return SetValueInLiteralHelper(value, linear_index, literal); + case U16: + return SetValueInLiteralHelper(value, linear_index, literal); + case U32: + return SetValueInLiteralHelper(value, linear_index, literal); + case U64: + return SetValueInLiteralHelper(value, linear_index, literal); + default: + LOG(FATAL) << "unknown integral primitive type " + << PrimitiveType_Name(shape.element_type()); + } +} + +bool HloParser::SetValueInLiteral(double value, int64 linear_index, + Literal* literal) { + const Shape& shape = literal->shape(); + switch (shape.element_type()) { + case F16: + return SetValueInLiteralHelper(value, linear_index, literal); + case BF16: + return SetValueInLiteralHelper(value, linear_index, literal); case F32: - double d; - if (!ParseDecimal(&d)) { - return false; - } - *literal = Literal::CreateR0(d); - return true; + return SetValueInLiteralHelper(value, linear_index, literal); + case F64: + return SetValueInLiteralHelper(value, linear_index, literal); + default: + LOG(FATAL) << "unknown floating point primitive type " + << PrimitiveType_Name(shape.element_type()); + } +} + +bool HloParser::SetValueInLiteral(bool value, int64 linear_index, + Literal* literal) { + const Shape& shape = literal->shape(); + switch (shape.element_type()) { + case PRED: + return SetValueInLiteralHelper(value, linear_index, literal); default: - return TokenError(StrCat("unsupported constant in shape: ", - ShapeUtil::HumanString(shape))); + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " is not PRED type"; + } +} + +template +bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + Literal* literal) { + // Check that linear_index is in range. + if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { + return TokenError( + StrCat("trys to set value ", value, " to a literal in shape ", + ShapeUtil::HumanString(literal->shape()), " at linear index ", + linear_index, ", but the index is out of range")); + } + + if (std::isnan(value) || + (std::numeric_limits::has_infinity && + (std::numeric_limits::infinity() == value || + -std::numeric_limits::infinity() == value))) { + // Skip range checking for non-finite value. + } else if (literal->shape().element_type() == F16 || + literal->shape().element_type() == BF16) { + if (value > kF16max || value < -kF16max) { + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); + } + } else if (value > static_cast( + std::numeric_limits::max()) || + value < static_cast( + std::numeric_limits::lowest())) { + // Value is out of range for LiteralNativeT. + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); } + + literal->GetMutableArraySlice().at(linear_index) = + static_cast(value); + return true; +} + +bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { + Shape new_shape; + if (!ParseShape(&new_shape)) { + return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape))); + } + if (!ShapeUtil::Compatible(shape, new_shape)) { + return TokenError(StrCat( + "expects shape ", ShapeUtil::HumanString(shape), + ", but sees a different shape: ", ShapeUtil::HumanString(new_shape))); + } + return true; +} + +// literal +// ::= tuple +// ::= non_tuple +bool HloParser::ParseLiteral(std::unique_ptr* literal, + const Shape& shape) { + return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) + : ParseNonTupleLiteral(literal, shape); +} + +// tuple +// ::= shape '(' literal_list ')' +// literal_list +// ::= /*empty*/ +// ::= literal (',' literal)* +bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, + const Shape& shape) { + if (!EatShapeAndCheckCompatible(shape)) { + return TokenError(StrCat("expects tuple constant in shape ", + ShapeUtil::HumanString(shape))); + } + if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { + return false; + } + std::vector> elements( + ShapeUtil::TupleElementCount(shape)); + + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + // literal, (',' literal)* + for (int i = 0; i < elements.size(); i++) { + if (i > 0) { + ParseToken(TokKind::kComma, "exepcts ',' to separate tuple elements"); + } + if (!ParseLiteral(&elements[i], + ShapeUtil::GetTupleElementShape(shape, i))) { + return TokenError(StrCat("expects the ", i, "th element")); + } + } + } + *literal = Literal::MakeTupleOwned(std::move(elements)); + return ParseToken(TokKind::kRparen, + StrCat("expects ')' at the end of the tuple with ", + ShapeUtil::TupleElementCount(shape), "elements")); +} + +// non_tuple +// ::= rank01 +// ::= rank2345 +// rank2345 ::= shape nested_array +bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, + const Shape& shape) { + const int64 rank = ShapeUtil::Rank(shape); + if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { + return false; + } + + // Create a literal with the given shape in default layout. + *literal = Literal::CreateFromDimensions(shape.element_type(), + AsInt64Slice(shape.dimensions())); + int64 nest_level = 0; + int64 linear_index = 0; + // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for + // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, + // when we are parsing the 2nd '{' (right before '1'), we are seeing a + // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at + // the first '}' (right after '3'), it means the sub-array ends, and the + // sub-array is supposed to contain exactly 3 elements, so check if + // elems_seen_per_dim[1] is 3. + std::vector elems_seen_per_dim(rank); + auto get_index_str = [&elems_seen_per_dim](int dim) -> string { + std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), + elems_seen_per_dim.begin() + dim); + return StrCat("[", + tensorflow::str_util::Join( + elems_seen_until_dim, ",", + [](string* out, const int64& num_elems) { + tensorflow::strings::StrAppend(out, num_elems - 1); + }), + "]"); + }; + do { + switch (lexer_.GetKind()) { + default: + return TokenError("unexpected token type in a literal"); + case TokKind::kLbrace: { + nest_level++; + if (nest_level > rank) { + return TokenError(Printf( + "expects nested array in rank %lld, but sees larger", rank)); + } + if (nest_level > 1) { + elems_seen_per_dim[nest_level - 2]++; + if (elems_seen_per_dim[nest_level - 2] > + shape.dimensions(nest_level - 2)) { + return TokenError(Printf( + "expects %lld elements in the %sth element, but sees more", + shape.dimensions(nest_level - 2), + get_index_str(nest_level - 2).c_str())); + } + } + lexer_.Lex(); + break; + } + case TokKind::kRbrace: { + nest_level--; + if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) { + return TokenError(Printf( + "expects %lld elements in the %sth element, but sees %lld", + shape.dimensions(nest_level), get_index_str(nest_level).c_str(), + elems_seen_per_dim[nest_level])); + } + elems_seen_per_dim[nest_level] = 0; + lexer_.Lex(); + break; + } + case TokKind::kComma: + case TokKind::kComment: + // Skip. + lexer_.Lex(); + break; + case TokKind::kw_true: + case TokKind::kw_false: + case TokKind::kInt: + case TokKind::kDecimal: + case TokKind::kw_nan: + case TokKind::kw_inf: + case TokKind::kNegInf: { + if (rank > 0) { + if (nest_level != rank) { + return TokenError( + Printf("expects nested array in rank %lld, but sees %lld", rank, + nest_level)); + } + elems_seen_per_dim[rank - 1]++; + if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { + return TokenError( + Printf("expects %lld elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); + } + } + if (lexer_.GetKind() == TokKind::kw_true || + lexer_.GetKind() == TokKind::kw_false) { + // TODO(congliu): bool type literals with rank >= 1 are actually + // printed in a compact form instead of "true" or "false". Fix that. + if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, + linear_index++, literal->get())) { + return false; + } + lexer_.Lex(); + } else if (primitive_util::IsIntegralType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); + int64 value; + if (!ParseInt64(&value)) { + return Error(loc, StrCat("expects integer for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + if (!SetValueInLiteral(value, linear_index++, literal->get())) { + return false; + } + } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); + double value; + if (!ParseDouble(&value)) { + return Error( + loc, StrCat("expect floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + if (!SetValueInLiteral(value, linear_index++, literal->get())) { + return false; + } + } else { + return TokenError(StrCat("unsupported premitive type ", + PrimitiveType_Name(shape.element_type()))); + } + break; + } + } // end of switch + } while (nest_level > 0); + + *literal = (*literal)->Relayout(shape.layout()); + return true; } // operands ::= '(' operands1 ')' @@ -594,6 +1382,7 @@ bool HloParser::ParseOperands(std::vector* operands) { // empty } else { do { + LocTy loc = lexer_.GetLoc(); Shape shape; string name; if (!ParseShape(&shape) || !ParseName(&name)) { @@ -602,7 +1391,7 @@ bool HloParser::ParseOperands(std::vector* operands) { HloInstruction* instruction = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); if (!instruction) { - return TokenError(StrCat("instruction does not exist: ", name)); + return Error(loc, StrCat("instruction does not exist: ", name)); } operands->push_back(instruction); } while (EatIfPresent(TokKind::kComma)); @@ -612,52 +1401,513 @@ bool HloParser::ParseOperands(std::vector* operands) { bool HloParser::ParseOperands(std::vector* operands, const int expected_size) { + LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; } if (expected_size != operands->size()) { - return TokenError(StrCat("expects ", expected_size, " operands, but has ", + return Error(loc, StrCat("expects ", expected_size, " operands, but has ", operands->size(), " operands")); } return true; } -// extra_attribute ::= ',' attribute_name value -template -bool HloParser::ParseExtraAttribute(T* value, - const string& expected_attribute) { - if (!ParseToken(TokKind::kComma, - "expects ',' in front of an extra attribute")) { +// sub_attributes ::= '{' (','? attribute)* '}' +bool HloParser::ParseSubAttributes( + const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); + if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { return false; } - string attribute_name; - if (!ParseAttributeName(&attribute_name) && - attribute_name != expected_attribute) { - return TokenError(StrCat("expects attribute name: ", expected_attribute)); + std::unordered_set seen_attrs; + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + } else { + do { + EatIfPresent(TokKind::kComma); + if (!ParseAttributeHelper(attrs, &seen_attrs)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); } - if (!ParseAttributeValue(value)) { - return TokenError( - StrCat("expects value for attribute: ", expected_attribute)); + // Check that all required attrs were seen. + for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return Error(loc, Printf("sub-attribute %s is expected but not seen", + attr_it.first.c_str())); + } + } + return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); +} + +// attributes ::= (',' attribute)* +bool HloParser::ParseAttributes( + const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); + std::unordered_set seen_attrs; + while (EatIfPresent(TokKind::kComma)) { + if (!ParseAttributeHelper(attrs, &seen_attrs)) { + return false; + } + } + // Check that all required attrs were seen. + for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return Error(loc, Printf("attribute %s is expected but not seen", + attr_it.first.c_str())); + } + } + return true; +} + +bool HloParser::ParseAttributeHelper( + const std::unordered_map& attrs, + std::unordered_set* seen_attrs) { + LocTy loc = lexer_.GetLoc(); + string name; + if (!ParseAttributeName(&name)) { + return Error(loc, "error parsing attributes"); + } + VLOG(1) << "Parsing attribute " << name; + if (!seen_attrs->insert(name).second) { + return Error(loc, Printf("attribute %s already exists", name.c_str())); + } + auto attr_it = attrs.find(name); + if (attr_it == attrs.end()) { + return Error(loc, Printf("unexpected attribute %s", name.c_str())); + } + AttrTy attr_type = attr_it->second.attr_type; + void* attr_out_ptr = attr_it->second.result; + bool success = [&] { + LocTy attr_loc = lexer_.GetLoc(); + switch (attr_type) { + case AttrTy::kInt64: { + int64 result; + if (!ParseInt64(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kInt32: { + int64 result; + if (!ParseInt64(&result)) { + return false; + } + if (result != static_cast(result)) { + return Error(attr_loc, "value out of range for int32"); + } + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); + return true; + } + case AttrTy::kFloat: { + double result; + if (!ParseDouble(&result)) { + return false; + } + if (result > std::numeric_limits::max() || + result < std::numeric_limits::lowest()) { + return Error(attr_loc, "value out of range for float"); + } + static_cast*>(attr_out_ptr) + ->emplace(static_cast(result)); + return true; + } + case AttrTy::kHloComputation: { + HloComputation* result; + if (!ParseComputationName(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kWindow: { + Window result; + if (!ParseWindow(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kConvolutionDimensionNumbers: { + ConvolutionDimensionNumbers result; + if (!ParseConvolutionDimensionNumbers(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSharding: { + OpSharding sharding; + if (!ParseSharding(&sharding)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(sharding); + return true; + } + case AttrTy::kInstructionList: { + std::vector result; + if (!ParseInstructionNames(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kFusionKind: { + HloInstruction::FusionKind result; + if (!ParseFusionKind(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kBracedInt64List: { + std::vector result; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + &result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSliceRanges: { + SliceRanges result; + if (!ParseSliceRanges(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kPaddingConfig: { + PaddingConfig result; + if (!ParsePaddingConfig(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kString: { + string result; + if (!ParseString(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kMetadata: { + OpMetadata result; + if (!ParseMetadata(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kDistribution: { + RandomDistribution result; + if (!ParseRandomDistribution(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } + } + }(); + if (!success) { + return Error(loc, Printf("error parsing attribute %s", name.c_str())); } return true; } -template <> -bool HloParser::ParseAttributeValue(HloComputation** value) { +bool HloParser::ParseComputationName(HloComputation** value) { string name; + LocTy loc = lexer_.GetLoc(); if (!ParseName(&name)) { - return TokenError("expects computation name"); + return Error(loc, "expects computation name"); } *value = tensorflow::gtl::FindPtrOrNull(computation_pool_, name); if (*value == nullptr) { - return TokenError(StrCat("computation does not exist: ", name)); + return Error(loc, StrCat("computation does not exist: ", name)); } return true; } -template <> -bool HloParser::ParseAttributeValue(int64* value) { - return ParseInt64(value); +// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' +// The subattributes can appear in any order. 'size=' is required, others are +// optional. +bool HloParser::ParseWindow(Window* window) { + LocTy loc = lexer_.GetLoc(); + if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { + return false; + } + + std::vector size; + std::vector stride; + std::vector> pad; + std::vector lhs_dilate; + std::vector rhs_dilate; + std::vector rhs_reversal; + while (lexer_.GetKind() != TokKind::kRbrace) { + LocTy attr_loc = lexer_.GetLoc(); + string field_name; + if (!ParseAttributeName(&field_name)) { + return Error(attr_loc, "expects sub-attributes in window"); + } + bool ok = [&] { + if (field_name == "size") { + return ParseDxD("size", &size); + } + if (field_name == "stride") { + return ParseDxD("stride", &stride); + } + if (field_name == "lhs_dilate") { + return ParseDxD("lhs_dilate", &lhs_dilate); + } + if (field_name == "rhs_dilate") { + return ParseDxD("rls_dilate", &rhs_dilate); + } + if (field_name == "pad") { + return ParseWindowPad(&pad); + } + if (field_name == "rhs_reversal") { + return ParseDxD("rhs_reversal", &rhs_reversal); + } + return Error(loc, StrCat("unexpected attribute name: ", field_name)); + }(); + if (!ok) { + return false; + } + } + + if (size.empty()) { + return Error(loc, + "sub-attribute 'size=' is required in the window attribute"); + } + if (!stride.empty() && stride.size() != size.size()) { + return Error(loc, "expects 'stride=' has the same size as 'size='"); + } + if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { + return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='"); + } + if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { + return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='"); + } + if (!pad.empty() && pad.size() != size.size()) { + return Error(loc, "expects 'pad=' has the same size as 'size='"); + } + + for (int i = 0; i < size.size(); i++) { + window->add_dimensions()->set_size(size[i]); + if (!pad.empty()) { + window->mutable_dimensions(i)->set_padding_low(pad[i][0]); + window->mutable_dimensions(i)->set_padding_high(pad[i][1]); + } + // If some field is not present, it has the default value. + window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]); + window->mutable_dimensions(i)->set_base_dilation( + lhs_dilate.empty() ? 1 : lhs_dilate[i]); + window->mutable_dimensions(i)->set_window_dilation( + rhs_dilate.empty() ? 1 : rhs_dilate[i]); + window->mutable_dimensions(i)->set_window_reversal( + rhs_reversal.empty() ? false : (rhs_reversal[i] == 1)); + } + return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); +} + +// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. +// The string looks like "dim_labels=0bf_0io->0bf". +bool HloParser::ParseConvolutionDimensionNumbers( + ConvolutionDimensionNumbers* dnums) { + if (lexer_.GetKind() != TokKind::kDimLabels) { + return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'"); + } + string str = lexer_.GetStrVal(); + + // The str is expected to have 3 items, lhs, rhs, out, and it must looks like + // lhs_rhs->out, that is, the first separator is "_" and the second is "->". + // So we replace the "->" with "_" and then split on "_". + str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", + /*newsub=*/"_", + /*replace_all=*/false); + std::vector lhs_rhs_out = Split(str, "_"); + if (lhs_rhs_out.size() != 3) { + LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " + << str; + } + + const int64 rank = lhs_rhs_out[0].length(); + if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { + return TokenError( + "convolution lhs, rhs, and output must have the same rank"); + } + if (rank < 2) { + return TokenError("convolution rank must >=2"); + } + + auto is_unique = [](string str) -> bool { + std::sort(str.begin(), str.end()); + return std::unique(str.begin(), str.end()) == str.end(); + }; + + // lhs + { + const string& lhs = lhs_rhs_out[0]; + if (!is_unique(lhs)) { + return TokenError( + StrCat("expects unique lhs dimension numbers, but sees ", lhs)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_input_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = lhs[i]; + if (c == 'b') { + dnums->set_input_batch_dimension(i); + } else if (c == 'f') { + dnums->set_input_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_input_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); + } + } + } + // rhs + { + const string& rhs = lhs_rhs_out[1]; + if (!is_unique(rhs)) { + return TokenError( + StrCat("expects unique rhs dimension numbers, but sees ", rhs)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_kernel_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = rhs[i]; + if (c == 'i') { + dnums->set_kernel_input_feature_dimension(i); + } else if (c == 'o') { + dnums->set_kernel_output_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_kernel_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); + } + } + } + // output + { + const string& out = lhs_rhs_out[2]; + if (!is_unique(out)) { + return TokenError( + StrCat("expects unique output dimension numbers, but sees ", out)); + } + for (int i = 0; i < rank - 2; i++) { + dnums->add_output_spatial_dimensions(-1); + } + for (int i = 0; i < rank; i++) { + char c = out[i]; + if (c == 'b') { + dnums->set_output_batch_dimension(i); + } else if (c == 'f') { + dnums->set_output_feature_dimension(i); + } else if (c < '0' + rank && c >= '0') { + dnums->set_output_spatial_dimensions(c - '0', i); + } else { + return TokenError( + Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); + } + } + } + + lexer_.Lex(); + return true; +} + +// ::= '{' ranges '}' +// ::= /*empty*/ +// ::= range (',' range)* +// range ::= '[' start ':' limit (':' stride)? ']' +// +// The slice ranges are printed as: +// +// {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...} +// +// This function extracts the starts, limits, and strides as 3 vectors to the +// result. If stride is not present, stride is 1. For example, if the slice +// ranges is printed as: +// +// {[2:3:4], [5:6:7], [8:9]} +// +// The the parsed result will be: +// +// {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}} +// +bool HloParser::ParseSliceRanges(SliceRanges* result) { + if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { + return false; + } + std::vector> ranges; + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); + } + do { + LocTy loc = lexer_.GetLoc(); + ranges.emplace_back(); + if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon, + &ranges.back())) { + return false; + } + const auto& range = ranges.back(); + if (range.size() != 2 && range.size() != 3) { + return Error(loc, Printf("expects [start:limit:step] or [start:limit], " + "but sees %ld elements.", + range.size())); + } + } while (EatIfPresent(TokKind::kComma)); + + for (const auto& range : ranges) { + result->starts.push_back(range[0]); + result->limits.push_back(range[1]); + result->strides.push_back(range.size() == 3 ? range[2] : 1); + } + return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); +} + +// int64list ::= start int64_elements end +// int64_elements +// ::= /*empty*/ +// ::= int64_val (delim int64_val)* +bool HloParser::ParseInt64List(const TokKind start, const TokKind end, + const TokKind delim, + std::vector* result) { + if (!ParseToken(start, StrCat("expects an int64 list starting with ", + TokKindToString(start)))) { + return false; + } + if (lexer_.GetKind() == end) { + // empty + } else { + do { + int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + } while (EatIfPresent(delim)); + } + return ParseToken( + end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } // param_list ::= '(' param_list1 ')' @@ -735,12 +1985,171 @@ bool HloParser::ParseAttributeName(string* result) { return true; } +bool HloParser::ParseString(string* result) { + VLOG(1) << "ParseString"; + if (lexer_.GetKind() != TokKind::kString) { + return TokenError("expects string"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseDxD(const string& name, std::vector* result) { + LocTy loc = lexer_.GetLoc(); + if (!result->empty()) { + return Error(loc, + Printf("sub-attribute '%s=' already exists", name.c_str())); + } + // 1D + if (lexer_.GetKind() == TokKind::kInt) { + int64 number; + if (!ParseInt64(&number)) { + return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); + } + result->push_back(number); + return true; + } + // 2D or higher. + if (lexer_.GetKind() == TokKind::kDxD) { + string str = lexer_.GetStrVal(); + if (!SplitAndParseAsInts(str, 'x', result)) { + return Error(loc, + Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + } + lexer_.Lex(); + return true; + } + return TokenError("expects token type kInt or kDxD"); +} + +bool HloParser::ParseWindowPad(std::vector>* pad) { + LocTy loc = lexer_.GetLoc(); + if (!pad->empty()) { + return Error(loc, "sub-attribute 'pad=' already exists"); + } + if (lexer_.GetKind() != TokKind::kPad) { + return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); + } + string str = lexer_.GetStrVal(); + std::vector padding_str = Split(str, 'x'); + for (int i = 0; i < padding_str.size(); i++) { + std::vector low_high; + if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || + low_high.size() != 2) { + return Error(loc, + "expects padding_low and padding_high separated by '_'"); + } + pad->push_back(low_high); + } + lexer_.Lex(); + return true; +} + +// This is the inverse xla::ToString(PaddingConfig). The padding config string +// looks like "0_0_0x3_3_1". The string is first separated by 'x', each +// substring represents one PaddingConfigDimension. The substring is 3 (or 2) +// numbers joined by '_'. +bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { + if (lexer_.GetKind() != TokKind::kPad) { + return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); + } + LocTy loc = lexer_.GetLoc(); + string str = lexer_.GetStrVal(); + std::vector padding_str = Split(str, 'x'); + for (const auto& padding_dim_str : padding_str) { + std::vector padding_dim; + if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || + (padding_dim.size() != 2 && padding_dim.size() != 3)) { + return Error(loc, + "expects padding config pattern like 'low_high_interior' or " + "'low_high'"); + } + auto* dim = padding->add_dimensions(); + dim->set_edge_padding_low(padding_dim[0]); + dim->set_edge_padding_high(padding_dim[1]); + dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0); + } + lexer_.Lex(); + return true; +} + +// '{' metadata_string '}' +bool HloParser::ParseMetadata(OpMetadata* metadata) { + std::unordered_map attrs; + optional op_type; + optional op_name; + optional source_file; + optional source_line; + attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; + attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; + attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; + attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (op_type) { + metadata->set_op_type(*op_type); + } + if (op_name) { + metadata->set_op_name(*op_name); + } + if (source_file) { + metadata->set_source_file(*source_file); + } + if (source_line) { + metadata->set_source_line(*source_line); + } + return true; +} + bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; - if (lexer_.GetKind() != TokKind::kOpcode) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects opcode"); } - *result = lexer_.GetOpcodeVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToHloOpcode(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects opcode but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { + VLOG(1) << "ParseFusionKind"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects fusion kind"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToFusionKind(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseRandomDistribution(RandomDistribution* result) { + VLOG(1) << "ParseRandomDistribution"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToRandomDistribution(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects random distribution but sees: %s, error: %s", + val.c_str(), status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } @@ -755,7 +2164,7 @@ bool HloParser::ParseInt64(int64* result) { return true; } -bool HloParser::ParseDecimal(double* result) { +bool HloParser::ParseDouble(double* result) { switch (lexer_.GetKind()) { case TokKind::kDecimal: *result = lexer_.GetDecimalVal(); @@ -763,6 +2172,15 @@ bool HloParser::ParseDecimal(double* result) { case TokKind::kInt: *result = static_cast(lexer_.GetInt64Val()); break; + case TokKind::kw_nan: + *result = std::numeric_limits::quiet_NaN(); + break; + case TokKind::kw_inf: + *result = std::numeric_limits::infinity(); + break; + case TokKind::kNegInf: + *result = -std::numeric_limits::infinity(); + break; default: return TokenError("expects decimal or integer"); } @@ -781,6 +2199,7 @@ bool HloParser::ParseBool(bool* result) { } bool HloParser::ParseToken(TokKind kind, const string& msg) { + VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg; if (lexer_.GetKind() != kind) { return TokenError(msg); } @@ -796,33 +2215,39 @@ bool HloParser::EatIfPresent(TokKind kind) { return true; } -bool HloParser::AddInstruction(const string& name, - HloInstruction* instruction) { +bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc) { auto result = instruction_pool_.insert({name, instruction}); if (!result.second) { - return TokenError(StrCat("instruction already exists: ", name)); + return Error(name_loc, StrCat("instruction already exists: ", name)); } return true; } -bool HloParser::AddComputation(const string& name, - HloComputation* computation) { +bool HloParser::AddComputation(const string& name, HloComputation* computation, + LocTy name_loc) { auto result = computation_pool_.insert({name, computation}); if (!result.second) { - return TokenError(StrCat("computation already exists: ", name)); + return Error(name_loc, StrCat("computation already exists: ", name)); } return true; } } // namespace -StatusOr> Parse(StringPiece str) { - HloParser parser(str); +StatusOr> Parse(StringPiece str, + const HloModuleConfig& config) { + HloParser parser(str, config); if (!parser.Run()) { - return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); } return parser.ConsumeHloModule(); } +StatusOr> Parse(StringPiece str) { + HloModuleConfig config; + return Parse(str, config); +} + } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h index 9aaf18ef20d769cd9ac6f0e48bc92f62292ba31a..2f97a2b9b19d0cdb64a2869913da62c55e14c1d5 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h @@ -28,7 +28,12 @@ namespace xla { namespace tools { // The api of the hlo parser. Given a string in the HloModule::ToString() -// format, returns the parsed HloModule. +// format, parses the string and creates a HloModule with the given config. +StatusOr> Parse(tensorflow::StringPiece str, + const HloModuleConfig& config); + +// The api of the hlo parser. Given a string in the HloModule::ToString() +// format, parses the string and creates a HloModule with default config. StatusOr> Parse(tensorflow::StringPiece str); } // namespace tools diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 5be4d6a2cb1b09355e09e25a40e8dc88bae01650..7eebc5dc93ffff1f5895e69023a4d81ab7279241 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -17,12 +17,16 @@ limitations under the License. #include #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace tools { namespace { +using tensorflow::StringPiece; +using tensorflow::strings::StrCat; + struct TestData { string test_name; string module_string; @@ -32,6 +36,10 @@ string TestDataToString(const ::testing::TestParamInfo& data) { return data.param.test_name; } +// For each string below, we check that: +// - we parse it to an HloModule successfully, and +// - the stringification of the resulting HloModule is equal to our original +// string. std::vector CreateTestCases() { // clang-format off return std::vector({ @@ -40,10 +48,11 @@ std::vector CreateTestCases() { "AxpyParam", R"(HloModule axpy_module: -ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { - %alpha = f32[2,4]{1,0} parameter(0) +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} %x = f32[2,4]{1,0} parameter(1) - %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) %y = f32[2,4]{1,0} parameter(2) ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) } @@ -56,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module: ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true) + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} } )" @@ -74,12 +83,80 @@ ENTRY %constant_s32 () -> s32[] { }, // f32 constant, but the value is not a decimal { -"ConstantF32", R"(HloModule ConstantF32_module: +"ConstantF32", +R"(HloModule ConstantF32_module: ENTRY %ConstantF32.v4 () -> f32[] { ROOT %constant = f32[] constant(42) } +)" +}, +// f32 constant, rank 1 empty array. +{ +"ConstantF32R1Empty", +R"(HloModule ConstantF32Empty_module: + +ENTRY %ConstantF32Empty.v4 () -> f32[0] { + ROOT %constant = f32[0]{0} constant({}) +} + +)" +}, +// f32 constant, rank 4 empty array. +{ +"ConstantF32R4Empty", +R"(HloModule ConstantF32R4Empty_module: + +ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { + ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) +} + +)" +}, +// constant 4D +{ +"Constant4D", +R"(HloModule Small_3x2x1x1_module: + +ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { + ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) +} + +)" +}, +// non-finite constants: nan, inf, -inf +{ +"ConstantNonFinite", +R"(HloModule IsFiniteR1F32s_module: + +ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { + %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf}) + ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant) +} + +)" +}, +// constant f16 +{ +"ConstantF16", +R"(HloModule ConstantF16_module: + +ENTRY %ConstantF16.v4 () -> f16[] { + ROOT %constant = f16[] constant(500) +} + +)" +}, +// bf16 +{ +"BF16", +R"(HloModule BF16: + +ENTRY %BF16.v4 () -> bf16[] { + ROOT %constant = bf16[] constant(500) +} + )" }, // constant + constant @@ -92,6 +169,17 @@ ENTRY %add_constants () -> f32[] { ROOT %add = f32[] add(f32[] %constant, f32[] %constant) } +)" +}, +// tuple constant +{ +"TupleConstant", +R"(HloModule TupleConstant_module: + +ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) +} + )" }, // v1 > v2 ? v1 : v2 @@ -103,7 +191,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3 %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated} - ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) + ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={} } )" @@ -131,6 +219,19 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) } +)" +}, +{ +"ShardedTupleCreate", +R"(HloModule ShardedTupleCreate_module: + +ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}} +} + )" }, // int32 result = 0; @@ -164,9 +265,11 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { R"(HloModule TwoSendRecvBothWayRecvFist_module: ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { - %recv = f32[] recv(), channel_id=15, sharding={maximal device=1} - ROOT %constant = f32[] constant(2.1), sharding={maximal device=0} - %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0} + %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} + ROOT %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15, sharding={maximal device=1} + %constant = f32[] constant(2.1), sharding={maximal device=0} + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} + %send-done = () send-done((f32[], u32[]) %send), channel_id=16, sharding={maximal device=0} } )" @@ -176,11 +279,11 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { "GetTupleElement", R"(HloModule GetTupleElement_module: -ENTRY %GetTupleElement.v4 () -> s32[] { - %constant = f32[] constant(1.23) - %constant.1 = s32[] constant(4) - %tuple = (f32[], s32[]) tuple(f32[] %constant, s32[] %constant.1) - ROOT %get-tuple-element = s32[] get-tuple-element((f32[], s32[]) %tuple), index=1, sharding={maximal device=0} +ENTRY %GetTupleElement.v4 () -> s32[2,3] { + %constant = f32[3]{0} constant({1, 2, 3}) + %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } }) + %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1) + ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0} } )" @@ -199,6 +302,407 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1 } +)" +}, +// reduce window +{ +"ReduceWindow", +R"(HloModule R4UnitWindow_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] { + %operand = f32[13,12,8,15]{0,3,2,1} parameter(0) + %constant = f32[] constant(0) + ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3 +} + +)" +}, +// reduce window on scalar +{ +"ReduceWindowScalar", +R"(HloModule reduce_window_scalar: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %R4UnitWindowScalar () -> f32[] { + %constant = f32[] constant(42) + %constant.1 = f32[] constant(1) + ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3 +} + +)" +}, +// convolution +{ +"Convolution", +R"(HloModule Convolve1D1Window_0_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f +} + +)" +}, +// convolution rank 2 +{ +"ConvolutionR2", +R"(HloModule ConvolveR2_module: + +ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { + %input = f32[1,2]{1,0} parameter(0) + %filter = f32[1,1]{1,0} parameter(1) + ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf +} + +)" +}, +// convolution backward +{ +"ConvolutionBackward", +R"(HloModule ConvolveBackward_module: + +ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { + %input = f32[128,7,7,512]{0,3,2,1} parameter(0) + %filter = f32[3,3,512,512]{3,2,1,0} parameter(1) + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f +} + +)" +}, +// reverse(constant) +{ +"Reverse4D", +R"(HloModule Reverse4DFloatArrayOnDim01_module: + +ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { + %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) + ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1} +} + +)" +}, +// concat +{ +"Concat", +R"(HloModule Concat2x3With2x5_module: + +ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { + %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) + %constant.1 = f32[2,5]{1,0} constant(f32[2,5] { { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } }) + ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} +} + +)" +}, +// map +{ +"Map", +R"(HloModule MapBinaryAdder_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] { + %param0 = f32[4]{0} parameter(0) + %param1 = f32[4]{0} parameter(1) + ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3 +} + +)" +}, +// reduce +{ +"Reduce", +R"(HloModule ReduceR3ToR2_module: + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] { + %input = f32[8,16,256]{2,1,0} parameter(0) + %constant = f32[] constant(0) + ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3 +} + +)" +}, +// select and scatter +{ +"SelectAndScatter", +R"(HloModule R4F32OverlapSmall_module: + +%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) +} + +%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { + %lhs.1 = f32[] parameter(0) + %rhs.1 = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) +} + +ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { + %constant = f32[4,5,1,1]{3,2,1,0} constant(f32[4,5,1,1] { { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } }) + %constant.1 = f32[2,2,1,1]{3,2,1,0} constant(f32[2,2,1,1] { { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } }) + %constant.2 = f32[] constant(0) + ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3 +} + +)" +}, +// select and scatter on scalar +{ +"SelectAndScatterScalar", +R"(HloModule select_and_scatter_scalar: + +%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %greater-than-or-equal-to = pred[] greater-than-or-equal-to(f32[] %lhs, f32[] %rhs) +} + +%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] { + %lhs.1 = f32[] parameter(0) + %rhs.1 = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1) +} + +ENTRY %SelectAndScatterScalar () -> f32[] { + %constant = f32[] constant(42) + %constant.1 = f32[] constant(1) + %constant.2 = f32[] constant(2) + ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3 +} + +)" +}, +// slice +{ +"Slice", +R"(HloModule slice_module: + +ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { + %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) + ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]} +} + +)" +}, +// slice, no stride +{ +"SliceNoStride", +R"(HloModule Slice3x3x3_To_1x3x3_F32_module: + +ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { + %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) + ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]} +} + +)" +}, +// slice R0 +{ +"SliceR0", +R"(HloModule SliceR0_module: + +ENTRY %SliceR0.v2 () -> s32[] { + %constant = s32[] constant(1) + ROOT %slice = s32[] slice(s32[] %constant), slice={} +} + +)" +}, +// transpose +{ +"Transpose", +R"(HloModule Transpose_module: + +ENTRY %Transpose.v2 () -> s32[1,2,3] { + %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) + ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2} +} + +)" +}, +// Dynamic slice +{ +"DynamicSlice", +R"(HloModule DynamicSlice_module: + +ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] { + %original_parameter = s32[2,2,258]{2,1,0} parameter(0) + %constant = s32[1]{0} constant({0}) + %start_index = s32[1]{0} parameter(1) + %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0} + ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258} +} + +)" +}, +// Dynamic update slice +{ +"DynamicUpdateSlice", +R"(HloModule DynamicUpdateSlice_module: + +ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { + %input = s32[1,1,25,1]{3,2,1,0} parameter(0) + %update = s32[1,1,2,1]{3,2,1,0} parameter(1) + %start_indices = s32[4]{0} parameter(2) + ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices) +} + +)" +}, +// batch norm training +{ +"BatchNormTraining", +R"(HloModule BasicTraining_module: + +ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { + %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) + %constant.1 = f32[2]{0} constant({2, 3}) + %constant.2 = f32[2]{0} constant({1, 2}) + ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3 +} + +)" +}, +// batch norm inference +{ +"BatchNormInference", +R"(HloModule BatchNormInference_module: + +ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] { + %input = f32[2,2,2,2]{3,2,1,0} parameter(0) + %offset = f32[2]{0} parameter(1) + %scale = f32[2]{0} parameter(2) + %mean = f32[2]{0} parameter(3) + %variance = f32[2]{0} parameter(4) + ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0 +} + +)" +}, +// batch norm grad +{ +"BatchNormGrad", +R"(HloModule BatchNormGrad_module: + +ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) { + %input = f32[2,2,2,2]{3,2,1,0} parameter(0) + %scale = f32[2]{0} parameter(1) + %mean = f32[2]{0} parameter(2) + %variance = f32[2]{0} parameter(3) + %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4) + ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0 +} + +)" +}, +// pad +{ +"Pad", +R"(HloModule Pad1DS3Array_module: + +ENTRY %Pad1DS3Array.v3 () -> f32[8] { + %constant = f32[3]{0} constant({1, 2, 3}) + %constant.1 = f32[] constant(0.1) + ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1 +} + +)" +}, +// pad has interior +{ +"PadHasInterior", +R"(HloModule PadHasInterior_module: + +ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { + %input = f32[1,25,7,7]{3,2,1,0} parameter(0) + %constant = f32[] constant(-5.123) + ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0 +} + +)" +}, +// fusion +{ +"Fusion", +R"(HloModule fusion_module: + +%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] { + %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) + %constant.1.param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +} + +ENTRY %fusion.v3 () -> f32[3,2,1,1] { + %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) + %constant.1 = f32[2]{0} constant({3.14, 4.25}) + ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation +} + +)" +}, +// infeed/outfeed +{ +"InfeedOutfeed", +R"(HloModule outfeed_module: + +ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { + %infeed = (u32[3]{0}, pred[]) infeed() + %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed) + ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed() + %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1) +} + +)" +}, +// Rng +{ +"Rng", +R"(HloModule rng_module: + +ENTRY %Rng () -> f32[8] { + %constant = f32[] constant(0) + %constant.1 = f32[] constant(1) + ROOT %rng = f32[8]{0} rng(f32[] %constant, f32[] %constant.1), distribution=rng_uniform +} + +)" +}, +// Reduce precision +{ +"ReducePrevison", +R"(HloModule reduce_precision: + +ENTRY %ReducePrecision () -> f32[1] { + %constant = f32[1]{0} constant({3.14159}) + ROOT %reduce-precision = f32[1]{0} reduce-precision(f32[1]{0} %constant), exponent_bits=8, mantissa_bits=10 +} + )" } }); @@ -208,15 +712,24 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { class HloParserTest : public ::testing::Test, public ::testing::WithParamInterface { protected: - void ExpectSuccess() { + static void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) + << "'" << s << "' does not contain '" << expected << "'"; + } + + // Expects "ToString(Parse(string)) == string", that is, parses the string, + // asserts that it succeeded, stringifies the parsed module, and checks that + // the it equals the original string. + void ExpectEqual() { const string& original = GetParam().module_string; auto result = Parse(original); - TF_EXPECT_OK(result.status()); - EXPECT_EQ(original, result.ValueOrDie()->ToString()); + TF_ASSERT_OK(result.status()); + EXPECT_EQ(original, + result.ValueOrDie()->ToString(/*include_large_constants=*/true)); } }; -TEST_P(HloParserTest, Run) { ExpectSuccess(); } +TEST_P(HloParserTest, Run) { ExpectEqual(); } INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, ::testing::ValuesIn(CreateTestCases()), @@ -301,6 +814,63 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { // but the constant names will not be exactly the same. } +TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { + const string original = R"(HloModule some_2_module: + +ENTRY %some_2 () -> f32[2] { + ROOT %constant = f32[2]{0} constant({1,{2}}) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "expects nested array in rank 1, but sees larger"); +} + +TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { + const string original = R"(HloModule some_2x3_module: + +ENTRY %some_2x3 () -> f32[2,3] { + ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "expects nested array in rank 2, but sees 1"); +} + +TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { + const string original = R"(HloModule some_2x3x2_module: + +ENTRY %some_2x3x2 () -> f32[2,3,2] { + ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "expects 3 elements in the [0]th element"); +} + +TEST_F(HloParserTest, ConstantF16Overflow) { + const string original = + R"(HloModule ConstantF16Overflow_module: + +ENTRY %ConstantF16Overflow.v4 () -> f16[] { + ROOT %constant = f16[] constant(-65505) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type F16"); +} + TEST_F(HloParserTest, ConstantWithExp) { const string original = R"(HloModule ConstantWithExp_module: @@ -316,6 +886,130 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { // printed as "300". } +TEST_F(HloParserTest, AttibutesAnyOrder) { + const string original = R"(HloModule any_order_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + +TEST_F(HloParserTest, InvalidDimLabels) { + string prefix = R"(HloModule invalid_dim_labels_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )"; + string suffix = R"( +} + +)"; + + ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); + + ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix)) + .status() + .error_message(), + "must have the same rank"); +} + +TEST_F(HloParserTest, UnexpectedAttribute) { + const string original = R"(HloModule unexpected_attr_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(2.1) + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, calls=%recv + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "unexpected attribute calls"); +} + +TEST_F(HloParserTest, MissingAttribute) { + const string original = R"(HloModule missing_attr_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(-2.1) + %send = (f32[], u32[]) send(f32[] %constant) + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "attribute channel_id is expected but not seen"); +} + +TEST_F(HloParserTest, PredecessorUndefined) { + const string original = R"(HloModule pre_not_found_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = (f32[], u32[]) recv(), channel_id=15 + %recv-done = f32[] recv-done((f32[], u32[]) %recv), channel_id=15 + ROOT %constant = f32[] constant(2.1) + %send = (f32[], u32[]) send(f32[] %constant), channel_id=16, control-predecessors={%done} + %send-done = () send-done((f32[], u32[]) %send), channel_id=16 +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "'done' is not defined"); +} + +TEST_F(HloParserTest, SliceAllowOmitStride1) { + const string original = R"(HloModule slice_module: + +ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { + %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) + ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + +TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { + const string original = R"(HloModule window_pad_module: + +ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { + %input = f32[1,2,1]{2,1,0} parameter(0) + %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input) + %filter = f32[1,1,1]{2,1,0} parameter(1) + ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1} +} + +)"; + ExpectHasSubstr(Parse(original).status().error_message(), + "expects padding_low and padding_high separated by '_'"); +} + +TEST_F(HloParserTest, CommaBetweenSubAttributes) { + const string original = R"(HloModule test_comma_module: + +ENTRY %test_comma.v4 () -> f32[] { + ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + } // namespace } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index a40300e2bf0d3279967826be6bf74875f8320f11..7928bee5c2097f353b182095a555c334d7b69c95 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -16,6 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ #define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" + namespace xla { namespace tools { @@ -36,7 +41,8 @@ enum class TokKind { kLparen, kRparen, // ( ) - kArrow, // -> + kArrow, // -> + kComment, // /*xxx*/ // Keywords kw_HloModule, @@ -46,16 +52,26 @@ enum class TokKind { kw_false, kw_maximal, kw_replicated, + kw_nan, + kw_inf, + + kNegInf, // -inf // Typed tokens. kName, // %foo kAttributeName, // dimensions= + kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} + kDxD, // [0-9]+(x[0-9]+)+ + kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers + kString, // "abcd\"\n" kShape, // f32[2,3]{1,0} - kOpcode, // add kInt, // 42 kDecimal, // 4.2 }; +string TokKindToString(TokKind kind); + } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 89b26b8916b67eeb38852c9e91314187fc8a7d48..a7dc5862057047f7c56faeb211cc0b13992caec7 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -58,18 +59,26 @@ namespace xla { namespace tools { namespace { +// Command-line opts to this tool. See main() for descriptions of these +// fields. +struct Options { + string fake_infeed_shape; + bool use_fake_data = false; + bool print_result = true; + int num_runs = 1; +}; + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; // otherwise, no infeed is performed. StatusOr> ReplayComputation( - const SessionModule& module, tensorflow::StringPiece fake_infeed_shape, - bool use_fake_data, Client* client) { + const SessionModule& module, Client* client, const Options& opts) { TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); std::vector> arguments; - if (use_fake_data) { + if (opts.use_fake_data) { arguments = MakeFakeArgumentsOrDie(computation, client); } else { // use recorded data if available for (const auto& proto : module.arguments()) { @@ -84,12 +93,12 @@ StatusOr> ReplayComputation( // concurrent infeed occur via the fake_infeed_shape. tensorflow::gtl::optional pool; - if (!fake_infeed_shape.empty()) { + if (!opts.fake_infeed_shape.empty()) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([fake_infeed_shape, client]() { + pool->Schedule([opts, client]() { StatusOr shape_status = - ShapeUtil::ParseShapeString(fake_infeed_shape); + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); TF_CHECK_OK(shape_status.status()); Shape shape = std::move(shape_status).ValueOrDie(); StatusOr> data_status = MakeFakeLiteral(shape); @@ -106,11 +115,32 @@ StatusOr> ReplayComputation( for (auto& argument : arguments) { execute_arguments.push_back(argument.get()); } - return client->ExecuteAndTransfer(computation, execute_arguments); + + // Run the computation num_runs times, and return the result from the last + // execution. + std::unique_ptr result; + for (int i = 0; i < opts.num_runs; ++i) { + ExecutionProfile profile; + if (opts.print_result) { + TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( + computation, execute_arguments, + /*execution_options=*/nullptr, &profile)); + } else { + // If we're not printing the result, execute the computation but don't + // bother retrieving the result. This can be a significant speedup. + TF_RETURN_IF_ERROR(client + ->Execute(computation, execute_arguments, + /*execution_options=*/nullptr, &profile) + .status()); + } + LOG(INFO) << "Execution took " + << static_cast(profile.compute_time_ns()) / 1e9 << "s"; + } + + return std::move(result); } -int RealMain(tensorflow::gtl::ArraySlice args, - tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) { +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { Client* client = ClientLibrary::LocalClientOrDie(); tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; @@ -118,21 +148,24 @@ int RealMain(tensorflow::gtl::ArraySlice args, SessionModule module; TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); StatusOr> result_status = - ReplayComputation(module, fake_infeed_shape, use_fake_data, client); + ReplayComputation(module, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); exit_status = EXIT_FAILURE; continue; } + std::unique_ptr result = result_status.ConsumeValueOrDie(); - fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), - ShapeUtil::HumanString(result->shape()).c_str(), - result->ToString().c_str()); - if (module.has_result()) { - fprintf(stdout, "was %s:%s\n", - ShapeUtil::HumanString(module.result().shape()).c_str(), - Literal(module.result()).ToString().c_str()); + if (result != nullptr) { + fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), + ShapeUtil::HumanString(result->shape()).c_str(), + result->ToString().c_str()); + if (module.has_result()) { + fprintf(stdout, "was %s:%s\n", + ShapeUtil::HumanString(module.result().shape()).c_str(), + Literal(module.result()).ToString().c_str()); + } } } return exit_status; @@ -143,13 +176,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, } // namespace xla int main(int argc, char** argv) { - // Flags - xla::string fake_infeed_shape; - bool use_fake_data = false; + xla::tools::Options opts; const std::vector flag_list = { - tensorflow::Flag("use_fake_data", &use_fake_data, + tensorflow::Flag("use_fake_data", &opts.use_fake_data, "Replay computation using fake data"), - tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape, + tensorflow::Flag("print_result", &opts.print_result, + "Print the result of the computation to stdout"), + tensorflow::Flag("num_runs", &opts.num_runs, + "Number of times to run each computation"), + tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); @@ -161,5 +196,5 @@ int main(int argc, char** argv) { tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - return xla::tools::RealMain(args, fake_infeed_shape, use_fake_data); + return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index 3b19ca321cad35aad18f7f498e08fd744ffbc371..9fa4297523bab0748863479be52dff1b7b523a8b 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/platform/types.h" #include @@ -32,6 +33,8 @@ using ::tensorflow::int16; using ::tensorflow::int32; using ::tensorflow::int64; +using ::tensorflow::bfloat16; + using ::tensorflow::uint8; using ::tensorflow::uint16; using ::tensorflow::uint32; diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 2624ef0252fd9482a600fe3aec07f7f328a86d69..fe5d29a6b655a89d559eb1214c2b8dd54d34094c 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -42,15 +42,15 @@ Status WithLogBacktrace(const Status& status) { } // namespace -ScopedLoggingTimer::ScopedLoggingTimer(const string& label, int32 vlog_level) - : label(label), vlog_level(vlog_level) { - if (VLOG_IS_ON(vlog_level)) { +ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled) + : enabled(enabled), label(label) { + if (enabled) { start_micros = tensorflow::Env::Default()->NowMicros(); } } ScopedLoggingTimer::~ScopedLoggingTimer() { - if (VLOG_IS_ON(vlog_level)) { + if (enabled) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); double secs = (end_micros - start_micros) / 1000000.0; @@ -191,9 +191,9 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice p) { - for (int64 i = 0; i < p.size(); ++i) { - if (p[i] != i) { +bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { + for (int64 i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) { return false; } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index f58f57b44396c90a3820835a3d0ecc792aaa7cd0..b722095d1f38bf8a984c3ce9092a65f8e0baa911 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -50,13 +50,43 @@ using DimensionVector = tensorflow::gtl::InlinedVector; // RAII timer that logs with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it // spits out the human-readable duration form. +// +// By default, the timing traces are only printed at VLOG(1) and above: +// +// XLA_SCOPED_LOGGING_TIMER("fooing bar"); // nop if !VLOG_IS_ON(1). +// +// but you can control this via: +// +// XLA_SCOPED_LOGGING_TIMER_LEVEL("fooing bar", 2); // nop if !VLOG_IS_ON(2) +// +#define XLA_SCOPED_LOGGING_TIMER(label) \ + XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__) +#define XLA_SCOPED_LOGGING_TIMER_LEVEL(label, level) \ + XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, __COUNTER__) + +// Helper for implementing macros above. Do not use directly. +// +// Forces the evaluation of "counter", which we expect is equal to __COUNTER__. +#define XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, counter) \ + XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) + +// Helper for macros above. Don't use directly. +#define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter) \ + ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \ + label, VLOG_IS_ON(level)) + +// RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL +// macros above. Recommended usage is via the macros so you don't have to give +// the timer a name or worry about calling VLOG_IS_ON yourself. struct ScopedLoggingTimer { - explicit ScopedLoggingTimer(const string& label, int32 vlog_level = 1); + // The timer does nothing if enabled is false. This lets you pass in your + // file's VLOG_IS_ON value. + ScopedLoggingTimer(const string& label, bool enabled); ~ScopedLoggingTimer(); - uint64 start_micros; + bool enabled; string label; - int32 vlog_level; + uint64 start_micros; }; // Given a vector, returns a MutableArraySlice that points at its diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 23161873a0b722dfbea34507fefc38a7a02c023d..293f0781a203d092a7996d5548de1dbf5bf32e4c 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -26,8 +26,8 @@ namespace xla { namespace window_util { /* static */ string ToString(const WindowDimension& dim) { - using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; + using tensorflow::strings::StrCat; string str = StrCat("(size=", dim.size()); if (dim.stride() != 1) { StrAppend(&str, ",stride=", dim.stride()); @@ -44,27 +44,30 @@ namespace window_util { if (dim.window_dilation() != 1) { StrAppend(&str, ",window_dilation=", dim.window_dilation()); } + if (dim.window_reversal()) { + StrAppend(&str, ",window_reversal"); + } StrAppend(&str, ")"); return str; } string ToString(const Window& window) { - using tensorflow::strings::StrCat; using tensorflow::strings::StrAppend; + using tensorflow::strings::StrCat; string str; - const auto add_field = [&]( - const char* heading, - std::function format) { - StrAppend(&str, heading, "="); - const char* prefix = ""; - for (const auto& window_dimension : window.dimensions()) { - StrAppend(&str, prefix, format(window_dimension)); - prefix = "x"; - } - }; - - add_field("window", + const auto add_field = + [&](const char* heading, + std::function format) { + StrAppend(&str, heading, "="); + const char* prefix = ""; + for (const auto& window_dimension : window.dimensions()) { + StrAppend(&str, prefix, format(window_dimension)); + prefix = "x"; + } + }; + + add_field("size", [](const WindowDimension& dim) { return StrCat(dim.size()); }); if (HasStride(window)) { add_field(" stride", @@ -85,6 +88,11 @@ string ToString(const Window& window) { return StrCat(dim.window_dilation()); }); } + if (HasWindowReversal(window)) { + add_field(" rhs_reversal", [](const WindowDimension& dim) { + return StrCat(dim.window_reversal() ? 1 : 0); + }); + } return str; } @@ -138,6 +146,15 @@ bool HasWindowDilation(const Window& window) { return false; } +bool HasWindowReversal(const Window& window) { + for (const auto& dim : window.dimensions()) { + if (dim.window_reversal()) { + return true; + } + } + return false; +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 235cb2d59d451a25dc4f824ab488f8cef6b03bfb..125900dac0c5ab478b834c315b4a438c9238ef6d 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -39,6 +39,8 @@ bool HasBaseDilation(const Window& window); bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); +bool HasWindowReversal(const Window& window); + // Returns the new bound after dilation. // // If a window with the given bound in some dimension is dilated with the given diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 3fa5bcc1df4f0294582b6c74735fef08c87433eb..6b136d333bbf079efd314833f46fe3b98743fbac 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -17,3 +17,5 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): protoc="@protobuf_archive//:protoc", testonly=testonly, visibility=visibility,) + +ORC_JIT_MEMORY_MAPPER_TARGETS = [] diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index ce3c3eee68ad7f7ebb42836e3cae14803f8650d7..127e5e81ac6d21945c7125ef913d236e8892758e 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -167,6 +167,14 @@ message DebugOptions { // computation will run 2! * 4! times. bool xla_test_all_input_layouts = 91; + // Assign colors based on sharding information when generating the Graphviz + // HLO graph. + bool xla_hlo_graph_sharding_color = 92; + + // Prefix the name scopes of the TF graph exports with "devX" device + // assignments, if available. + bool xla_hlo_tfgraph_device_scopes = 93; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; @@ -361,6 +369,7 @@ message WaitForExecutionResponse { message IsConstantRequest { ComputationHandle computation = 1; ComputationDataHandle operand = 2; + int64 num_parameters = 3; } message IsConstantResponse { @@ -371,6 +380,7 @@ message ComputeConstantRequest { ComputationHandle computation = 1; ComputationDataHandle operand = 2; Layout output_layout = 3; + repeated LiteralProto parameters = 4; } message ComputeConstantResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 080e3c4267a2dca2b70c5cff51126cbf4b3e2881..215707634bc29263bc1ef472f498ac1bb1ca9181 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -46,6 +46,12 @@ enum PrimitiveType { // converted to f16 from f32 at arbirary points in the computation. F16 = 10; F32 = 11; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent + // and 7 bits for the mantissa. + BF16 = 16; + F64 = 12; // Complex values of fixed width. @@ -63,6 +69,8 @@ enum PrimitiveType { // An opaque type used for passing context specific data to a custom // operation. OPAQUE = 14; + + // Next = 17 } // Describes the value held inside padding elements. @@ -310,7 +318,10 @@ message LiteralProto { repeated double f64s = 9; repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated LiteralProto tuple_literals = 10; - bytes f16s = 11; // Note: the F16s are encoded in little endian byte order + // The F16s and BF16s are encoded in little endian byte order + bytes f16s = 11; + bytes bf16s = 13; + // Next = 14 } message WindowDimension { @@ -346,6 +357,10 @@ message WindowDimension { // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly // placed between each base area element. See documentation for convolution. int64 base_dilation = 6; + + // Window reversal means that this dimension was logically reversed before the + // operation. + bool window_reversal = 7; } // Describes the windowing in an operation such as convolution. @@ -402,15 +417,9 @@ message ConvolutionDimensionNumbers { // The number of the dimension that represents features in the input. int64 input_feature_dimension = 8; - // The number of the dimension that represents batch in the output. - int64 output_batch_dimension = 9; - - // The number of the dimension that represents features in the output. - int64 output_feature_dimension = 10; - // The dimension numbers for the spatial dimensions that the window - // moves through in the input (lhs) and output. - repeated int64 spatial_dimensions = 5; + // moves through in the input. + repeated int64 input_spatial_dimensions = 11; // The number of the dimension that represents input features in the // convolutional kernel (rhs). @@ -424,15 +433,41 @@ message ConvolutionDimensionNumbers { // moves through in the kernel (rhs). window.strides(0) is the // stride in the kernel_spatial_dimensions(0) dimension. repeated int64 kernel_spatial_dimensions = 6; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; + + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the output. + repeated int64 output_spatial_dimensions = 12; + + // Next = 13 }; message ConvolveRequest { ComputationDataHandle lhs = 2; ComputationDataHandle rhs = 3; // This is the filter/kernel. - Window window = 4; // Describes the filter/kenel. + Window window = 4; // Describes the filter/kernel. ConvolutionDimensionNumbers dimension_numbers = 5; } +enum FftType { + FFT = 0; // Forward FFT; complex in, complex out. + IFFT = 1; // Inverse FFT; complex in, complex out. + RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out + IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in, + // fft_length real out +} + +message FftRequest { + FftType fft_type = 1; + repeated int64 fft_length = 2; // Multivalent for higher-order FFT. + ComputationDataHandle operand = 3; +} + message InfeedRequest { // The shape of the data returned by reading the device's infeed buffer. Shape shape = 2; @@ -463,6 +498,23 @@ message CustomCallRequest { Shape shape = 4; } +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. + repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +}; + +message DotRequest { + ComputationDataHandle lhs = 2; + ComputationDataHandle rhs = 3; + DotDimensionNumbers dimension_numbers = 4; +} + message MapRequest { repeated ComputationDataHandle operands = 2; ComputationHandle to_apply = 3; @@ -616,6 +668,14 @@ message ConcatenateRequest { int64 dimension = 3; } +message ConditionalRequest { + ComputationDataHandle predicate = 2; + ComputationDataHandle true_operand = 3; + ComputationHandle true_computation = 4; + ComputationDataHandle false_operand = 5; + ComputationHandle false_computation = 6; +} + message WhileRequest { ComputationHandle condition = 2; ComputationHandle body = 3; @@ -697,9 +757,6 @@ enum BinaryOperation { BINOP_LT = 9; BINOP_NE = 10; - // Dot product, matrix multiply. - BINOP_DOT = 12; - // Element-wise maximum. BINOP_MAX = 14; @@ -811,8 +868,10 @@ message OpSharding { REPLICATED = 0; // This sharding is maximal - one device runs the entire operation. MAXIMAL = 1; - // Neither of the above; tile_shape and tile_assignment are both used. - OTHER = 2; + // This sharding is a tuple - only the tuple_shardings field is valid. + TUPLE = 2; + // None of the above; tile_shape and tile_assignment are both used. + OTHER = 3; } Type type = 1; // The shape of the sharded tile. @@ -824,6 +883,13 @@ message OpSharding { // Flattened list of device IDs. The order of flattening is the same as used // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). repeated int64 tile_assignment_devices = 4; + // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, + // in pre-order. The tuple shape could be nested; here we store just a + // flattened list of all leaves in the tuple shape. Note that the tuple shape + // is not stored here; shardings do not store the shapes to which they are + // applied, this is inferred from the instruction this sharding gets attached + // to. + repeated OpSharding tuple_shardings = 5; } message OpRequest { @@ -841,6 +907,7 @@ message OpRequest { ConvolveRequest convolve_request = 8; CrossReplicaSumRequest cross_replica_sum_request = 9; CustomCallRequest custom_call_request = 10; + DotRequest dot_request = 43; DynamicSliceRequest dynamic_slice_request = 11; DynamicUpdateSliceRequest dynamic_update_slice_request = 12; GetTupleElementRequest get_tuple_element_request = 13; @@ -868,7 +935,10 @@ message OpRequest { BatchNormTrainingRequest batch_norm_training_request = 35; BatchNormGradRequest batch_norm_grad_request = 37; BatchNormInferenceRequest batch_norm_inference_request = 38; - // Next: 41 + FftRequest fft_request = 41; + ConvertRequest bitcast_convert_request = 42; + ConditionalRequest conditional_request = 44; + // Next: 45 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 2e9b96bb1d31f7c985df992c094784660d6e274c..604c41bf8acc910b47f8ee4a871d4740a2f1ba2f 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -9,7 +9,12 @@ load("//third_party/mpi:mpi.bzl", "if_mpi") py_library( name = "contrib_py", - srcs = glob(["**/*.py"]), + srcs = glob( + ["**/*.py"], + exclude = [ + "**/*_test.py", + ], + ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ @@ -51,17 +56,20 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", + "//tensorflow/contrib/lite/python:lite", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/losses:metric_learning_py", "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index a26fdb982c0f4d6d85b73912c194647a989d0ef6..08247c6b38a4df663ad28a6b4d3c41a1da41a020 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -51,9 +51,11 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt +from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor from tensorflow.contrib import quantization from tensorflow.contrib import quantize @@ -78,6 +80,7 @@ from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager +from tensorflow.contrib.lite.python import lite from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index f49e5857fe5255c2459793cb1389052a2ff5f88f..c7c128bf14f03d3769ef08e83da61f6d2f91fbd2 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -15,9 +15,9 @@ For prebuilt libraries, see the page for a recent build. The TensorFlow Inference Interface is also available as a -[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and -can be included quite simply in your android project with a couple of lines in -the project's `build.gradle` file: +[JCenter package](https://bintray.com/google/tensorflow/tensorflow) +(see the tensorflow-android directory) and can be included quite simply in your +android project with a couple of lines in the project's `build.gradle` file: ``` allprojects { diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc index 9e4d3290c3d99fab42f512f7144defde54f8ece8..380a652435ad089f46f3ca80e4fd43097fd96e10 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.cc +++ b/tensorflow/contrib/android/asset_manager_filesystem.cc @@ -97,7 +97,7 @@ class RandomAccessFileFromAsset : public RandomAccessFile { off64_t new_offset = AAsset_seek64(asset.get(), offset, SEEK_SET); off64_t length = AAsset_getLength64(asset.get()); if (new_offset < 0) { - result->set(scratch, 0); + *result = StringPiece(scratch, 0); return errors::OutOfRange("Read after file end."); } const off64_t region_left = @@ -106,7 +106,7 @@ class RandomAccessFileFromAsset : public RandomAccessFile { if (read < 0) { return errors::Internal("Error reading from asset."); } - result->set(scratch, region_left); + *result = StringPiece(scratch, region_left); return (region_left == to_read) ? Status::OK() : errors::OutOfRange("Read less bytes than requested."); diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index 25ada5ba27aa167e4aaf4cebd6517e3b80aa1058..a115d1610e2334a6626f29674f3dd195e3a3c648 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -34,10 +34,12 @@ add_library(lib_tf STATIC IMPORTED ) set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION ${PREBUILT_DIR}/lib/libtensorflow-core.a) # Change to compile flags should be replicated into bazel build file +# TODO: Consider options other than -O2 for binary size. +# e.g. -Os for gcc, and -Oz for clang. set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ -std=c++11 -fno-rtti -fno-exceptions \ -O2 -Wno-narrowing -fomit-frame-pointer \ - -mfpu=neon -mfloat-abi=softfp -fPIE \ + -mfpu=neon -mfloat-abi=softfp -fPIE -fPIC \ -ftemplate-depth=900 \ -DGOOGLE_PROTOBUF_NO_RTTI \ -DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER") diff --git a/tensorflow/contrib/android/cmake/README.md b/tensorflow/contrib/android/cmake/README.md index 6f19b657fe72064bd7b005b568540cd52a5e19e8..934b58c7242fc06064ee3c06bc8f4c2740bd24ef 100644 --- a/tensorflow/contrib/android/cmake/README.md +++ b/tensorflow/contrib/android/cmake/README.md @@ -14,7 +14,7 @@ Add TensorFlow-Android-Inference as a dependency of your Android application ``` include ':TensorFlow-Android-Inference' -findProject(":TensorFlow-Android-Inference").projectDir = +findProject(":TensorFlow-Android-Inference").projectDir = new File("${/path/to/tensorflow_repo}/contrib/android/cmake") ``` diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 1f423a7a5bf6a115dc627ddd6f5e98c074282585..dc5b9fb88742d78d0f40207b589e29451a6358dd 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -160,7 +160,7 @@ public class TensorFlowInferenceInterface { throw new RuntimeException("Failed to load model from the input stream", e); } } - + /* * Construct a TensorFlowInferenceInterface with provided Graph * @@ -168,7 +168,7 @@ public class TensorFlowInferenceInterface { */ public TensorFlowInferenceInterface(Graph g) { prepareNativeRuntime(); - + // modelName is redundant here, here is for // avoiding error in initialization as modelName is marked final. this.modelName = ""; @@ -290,7 +290,7 @@ public class TensorFlowInferenceInterface { */ public void feed(String inputName, boolean[] src, long... dims) { byte[] b = new byte[src.length]; - + for (int i = 0; i < src.length; i++) { b[i] = src[i] ? (byte) 1 : (byte) 0; } diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 8b7df4a84c558f662405a28a42426583d5ab39cd..a111cfecb366fe245150cc71d2c43662d0d69090 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -82,6 +82,7 @@ cc_library( tf_cc_test( name = "adaptive_shared_batch_scheduler_test", srcs = ["adaptive_shared_batch_scheduler_test.cc"], + tags = ["manual"], # b/69013768 deps = [ ":adaptive_shared_batch_scheduler", "//tensorflow/contrib/batching/test_util:fake_clock_env", diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index a0606427a526ffc67e10d12a084eabc64564e4ab..9e32bee505640ea04edfeffea0a14d1937c3a2b1 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -208,6 +208,8 @@ class ASBSQueue : public BatchScheduler { // place any more tasks in this batch. void ReleaseBatch(const ASBSBatch* batch); + size_t max_task_size() const override { return options_.max_batch_size; } + private: std::shared_ptr> scheduler_; const QueueOptions options_; @@ -399,7 +401,7 @@ ASBSQueue::~ASBSQueue() { template Status ASBSQueue::Schedule(std::unique_ptr* task) { - bool added_new_batch = false; + ASBSBatch* new_batch = nullptr; size_t size = (*task)->size(); if (size > options_.max_batch_size) { return errors::InvalidArgument("Task size ", size, @@ -418,15 +420,14 @@ Status ASBSQueue::Schedule(std::unique_ptr* task) { current_batch_ = nullptr; } if (!current_batch_) { - added_new_batch = true; num_enqueued_batches_++; - current_batch_ = + current_batch_ = new_batch = new ASBSBatch(this, scheduler_->GetEnv()->NowMicros()); } current_batch_->AddTask(std::move(*task)); num_enqueued_tasks_++; } - if (added_new_batch) scheduler_->AddBatch(current_batch_); + if (new_batch != nullptr) scheduler_->AddBatch(new_batch); return Status::OK(); } diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc index a07cd6d834fa28904bf7748b16972cca217503c1..e2aac54eebccaf53da9560591cfe909989774bab 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc @@ -186,6 +186,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { queue_options.max_enqueued_batches = 2; TF_ASSERT_OK( scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); + EXPECT_EQ(10, queue_0->max_task_size()); queue_options.max_batch_size = 0; // Queue must have max_batch_size > 0. EXPECT_FALSE( diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h index 9d3805fbaf39978159dd2f4a754e6d41a07acf6a..91065db2499dffd2687a53bd6304d9b7593f7b3a 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler.h +++ b/tensorflow/contrib/batching/basic_batch_scheduler.h @@ -192,6 +192,10 @@ class BasicBatchScheduler : public BatchScheduler { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { + return shared_scheduler_queue_->max_task_size(); + } + private: explicit BasicBatchScheduler( std::unique_ptr> shared_scheduler_queue); diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc index e020301795c7dadee2815c0e0d727e53e5fb9e6e..187823151cf840dcf8058677fcf74d1beffc3bc2 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc @@ -73,6 +73,7 @@ TEST(BasicBatchSchedulerTest, Basic) { std::unique_ptr> scheduler; TF_ASSERT_OK( BasicBatchScheduler::Create(options, callback, &scheduler)); + EXPECT_EQ(10, scheduler->max_task_size()); EXPECT_EQ(0, scheduler->NumEnqueuedTasks()); EXPECT_EQ(3 * 10, scheduler->SchedulingCapacity()); TF_ASSERT_OK(ScheduleTask(3, scheduler.get())); diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index a5072f439abad3c5db79a514a7f2baff0b021b39..e18cf6c35059e4d720768e3b2c02b03727a6bac4 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -178,6 +178,10 @@ class BatchScheduler { // This method is useful for monitoring, or for guaranteeing a future slot in // the schedule (but being mindful about the caveats listed above). virtual size_t SchedulingCapacity() const = 0; + + // Returns the maximum allowed size of tasks submitted to the scheduler. (This + // is typically equal to a configured maximum batch size.) + virtual size_t max_task_size() const = 0; }; ////////// diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc index 3b7c538fcc42b2e8f100d374c273ee3ca3d6056b..6041d8c9b2ca14bd325d1e7ea562bc4bc27d6a51 100644 --- a/tensorflow/contrib/batching/kernels/batch_kernels.cc +++ b/tensorflow/contrib/batching/kernels/batch_kernels.cc @@ -461,7 +461,7 @@ class BatchResource : public ResourceBase { return Status::OK(); } - // Looks up the batcher queue for 'queue_name'. If it did't previously exist, + // Looks up the batcher queue for 'queue_name'. If it didn't previously exist, // creates it. Status LookupOrCreateBatcherQueue(const string& queue_name, BatcherQueue** queue) { diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h index 41a3f99137ade2552432fee62ddce17d064148a4..1d2158062e589db71b7df4c47af1b7851b41a036 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/shared_batch_scheduler.h @@ -248,6 +248,9 @@ class Queue { // BatchScheduler::SchedulingCapacity(). size_t SchedulingCapacity() const; + // Returns the maximum allowed size of tasks submitted to the queue. + size_t max_task_size() const { return options_.max_batch_size; } + // Called by a thread that is ready to process a batch, to request one from // this queue. Either returns a batch that is ready to be processed, or // nullptr if the queue declines to schedule a batch at this time. If it @@ -338,6 +341,8 @@ class QueueHandle : public BatchScheduler { size_t NumEnqueuedTasks() const override; size_t SchedulingCapacity() const override; + size_t max_task_size() const override { return queue_->max_task_size(); } + private: // The scheduler that owns 'queue_'. std::shared_ptr> scheduler_; diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc index 3e924ae5f13519b4fe9a3f4b510773ca2bddaf23..3ac79a8fdc47389816db8ca09f27846d1c4623c2 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc +++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc @@ -429,6 +429,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) { queue_options.max_enqueued_batches = max_enqueued_batches; std::unique_ptr> queue; TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue)); + EXPECT_EQ(2, queue->max_task_size()); EXPECT_EQ(0, queue->NumEnqueuedTasks()); EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity()); diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 8bb742d289a0836378a9a03c90d46293cfbfe75b..a262d4aecdbb69dfcd8b88bc0a09060500d6b1c9 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -3,12 +3,15 @@ # particularly useful for Bayesian inference. # APIs here are meant to evolve over time. +package(default_visibility = [ + "//learning/brain/contrib/bayesflow:__subpackages__", + "//tensorflow:__subpackages__", +]) + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( @@ -16,9 +19,9 @@ py_library( srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:functional_ops", @@ -29,12 +32,8 @@ py_library( "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:state_ops", - "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/ops/distributions", "//third_party/py/numpy", - "@six_archive//:six", ], ) @@ -101,61 +100,61 @@ cuda_py_test( ) cuda_py_test( - name = "entropy_test", - size = "medium", - srcs = ["python/kernel_tests/entropy_test.py"], + name = "layers_dense_variational_test", + size = "small", + srcs = ["python/kernel_tests/layers_dense_variational_test.py"], additional_deps = [ ":bayesflow_py", "//third_party/py/numpy", "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", - "//tensorflow/python:variables", ], ) cuda_py_test( - name = "stochastic_variables_test", - size = "medium", - srcs = ["python/kernel_tests/stochastic_variables_test.py"], + name = "monte_carlo_test", + size = "small", + srcs = ["python/kernel_tests/monte_carlo_test.py"], additional_deps = [ ":bayesflow_py", "//third_party/py/numpy", "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python/ops/distributions", "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", + "//tensorflow/python:random_seed", ], ) cuda_py_test( - name = "monte_carlo_test", + name = "halton_sequence_test", size = "small", - srcs = ["python/kernel_tests/monte_carlo_test.py"], + srcs = ["python/kernel_tests/halton_sequence_test.py"], additional_deps = [ ":bayesflow_py", "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:random_seed", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", ], ) @@ -181,84 +180,23 @@ cuda_py_test( ) cuda_py_test( - name = "stochastic_graph_test", + name = "sgld_optimizer_test", size = "small", - srcs = ["python/kernel_tests/stochastic_graph_test.py"], - additional_deps = [ - ":bayesflow_py", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_test( - name = "variational_inference_test", - size = "small", - srcs = ["python/kernel_tests/variational_inference_test.py"], - additional_deps = [ - ":bayesflow_py", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "stochastic_tensor_test", - size = "small", - srcs = ["python/kernel_tests/stochastic_tensor_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_test( - name = "stochastic_gradient_estimators_test", - size = "medium", - srcs = ["python/kernel_tests/stochastic_gradient_estimators_test.py"], + srcs = ["python/kernel_tests/sgld_optimizer_test.py"], additional_deps = [ ":bayesflow_py", "//third_party/py/numpy", "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python/ops/distributions", "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "reinforce_simple_example", - size = "small", - srcs = ["examples/reinforce_simple/reinforce_simple_example.py"], - additional_deps = [ - ":bayesflow_py", - "//tensorflow:tensorflow_py", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", ], ) diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 8b27fa76bd31a926558abe681d6e510c0a4997c1..95b9452b1ada60c44672f37800ced2133d2bd8b2 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -23,24 +23,30 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence from tensorflow.contrib.bayesflow.python.ops import custom_grad -from tensorflow.contrib.bayesflow.python.ops import entropy +from tensorflow.contrib.bayesflow.python.ops import halton_sequence from tensorflow.contrib.bayesflow.python.ops import hmc +from tensorflow.contrib.bayesflow.python.ops import layers from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo -from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators -from tensorflow.contrib.bayesflow.python.ops import stochastic_graph -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.contrib.bayesflow.python.ops import stochastic_variables -from tensorflow.contrib.bayesflow.python.ops import variational_inference +from tensorflow.contrib.bayesflow.python.ops import optimizers # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy', - 'metropolis_hastings', 'monte_carlo', 'hmc', 'special_math', - 'stochastic_gradient_estimators', 'stochastic_graph', - 'stochastic_tensor', 'stochastic_variables', - 'variational_inference'] +_allowed_symbols = [ + 'csiszar_divergence', + 'custom_grad', + 'entropy', + 'halton_sequence', + 'hmc', + 'layers', + 'metropolis_hastings', + 'monte_carlo', + 'optimizers', + 'special_math', + 'stochastic_variables', + 'variational_inference', +] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py b/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py deleted file mode 100644 index 2eb625487f4cd18bdec10ddbc0cf64cb8c8499b8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Simple examples of the REINFORCE algorithm.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - - -distributions = tf.contrib.distributions -sg = tf.contrib.bayesflow.stochastic_graph -st = tf.contrib.bayesflow.stochastic_tensor - - -def split_apply_merge(inp, partitions, fns): - """Split input according to partitions. Pass results through fns and merge. - - Args: - inp: the input vector - partitions: tensor of same length as input vector, having values 0, 1 - fns: the two functions. - - Returns: - the vector routed, where routed[i] = fns[partitions[i]](inp[i]) - """ - new_inputs = tf.dynamic_partition(inp, partitions, len(fns)) - new_outputs = [fns[i](x) for i, x in enumerate(new_inputs)] - new_indices = tf.dynamic_partition( - tf.range(0, inp.get_shape()[0]), partitions, len(fns)) - return tf.dynamic_stitch(new_indices, new_outputs) - - -def plus_1(inputs): - return inputs + 1.0 - - -def minus_1(inputs): - return inputs - 1.0 - - -def build_split_apply_merge_model(): - """Build the Split-Apply-Merge Model. - - Route each value of input [-1, -1, 1, 1] through one of the - functions, plus_1, minus_1. The decision for routing is made by - 4 Bernoulli R.V.s whose parameters are determined by a neural network - applied to the input. REINFORCE is used to update the NN parameters. - - Returns: - The 3-tuple (route_selection, routing_loss, final_loss), where: - - - route_selection is an int 4-vector - - routing_loss is a float 4-vector - - final_loss is a float scalar. - """ - inputs = tf.constant([[-1.0], [-1.0], [1.0], [1.0]]) - targets = tf.constant([[0.0], [0.0], [0.0], [0.0]]) - paths = [plus_1, minus_1] - weights = tf.get_variable("w", [1, 2]) - bias = tf.get_variable("b", [1, 1]) - logits = tf.matmul(inputs, weights) + bias - - # REINFORCE forward step - route_selection = st.StochasticTensor( - distributions.Categorical(logits=logits)) - - # Accessing route_selection as a Tensor below forces a sample of - # the Categorical distribution based on its logits. - # This is equivalent to calling route_selection.value(). - # - # route_selection.value() returns an int32 4-vector with random - # values in {0, 1} - # COPY+ROUTE+PASTE - outputs = split_apply_merge(inputs, route_selection, paths) - - # flatten routing_loss to a row vector (from a column vector) - routing_loss = tf.reshape(tf.square(outputs - targets), shape=[-1]) - - # Total loss: score function loss + routing loss. - # The score function loss (through `route_selection.loss(routing_loss)`) - # returns: - # [stop_gradient(routing_loss) * - # route_selection.log_pmf(stop_gradient(route_selection.value()))], - # where log_pmf has gradients going all the way back to weights and bias. - # In this case, the routing_loss depends on the variables only through - # "route_selection", which has a stop_gradient on it. So the - # gradient of the loss really come through the score function - surrogate_loss = sg.surrogate_loss([routing_loss]) - final_loss = tf.reduce_sum(surrogate_loss) - - return (route_selection, routing_loss, final_loss) - - -class REINFORCESimpleExample(tf.test.TestCase): - - def testSplitApplyMerge(self): - # Repeatability. SGD has a tendency to jump around, even here. - tf.set_random_seed(1) - - with self.test_session() as sess: - # Use sampling to train REINFORCE - with st.value_type(st.SampleValue()): - (route_selection, - routing_loss, - final_loss) = build_split_apply_merge_model() - - sgd = tf.train.GradientDescentOptimizer(1.0).minimize(final_loss) - - tf.global_variables_initializer().run() - - for i in range(10): - # Run loss and inference step. This toy problem converges VERY quickly. - (routing_loss_v, final_loss_v, route_selection_v, _) = sess.run( - [routing_loss, final_loss, tf.identity(route_selection), sgd]) - print( - "Iteration %d, routing loss: %s, final_loss: %s, " - "route selection: %s" - % (i, routing_loss_v, final_loss_v, route_selection_v)) - - self.assertAllEqual([0, 0, 1, 1], route_selection_v) - self.assertAllClose([0.0, 0.0, 0.0, 0.0], routing_loss_v) - self.assertAllClose(0.0, final_loss_v) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py index 8c6a614beb194180d8b075526a5395aa65d354de..2e94b7206de4f7c40c89f083f3bfa2a22bb7b917 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py @@ -759,7 +759,7 @@ class CsiszarVIMCOTest(test.TestCase): def _csiszar_vimco_helper_grad(self, logu, delta): """Finite difference approximation of `grad(csiszar_vimco_helper, logu)`.""" - # This code actually estimates the sum of the Jacobiab because thats what + # This code actually estimates the sum of the Jacobiab because that's what # TF's `gradients` does. np_log_avg_u1, np_log_sooavg_u1 = self._csiszar_vimco_helper( logu[..., None] + np.diag([delta]*len(logu))) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py deleted file mode 100644 index 0bd12b84d12a9c3219f6b24830b1b82db9716043..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py +++ /dev/null @@ -1,352 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Monte Carlo Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import layers as layers_lib -from tensorflow.contrib.bayesflow.python.ops import entropy_impl as entropy -from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib -from tensorflow.contrib.distributions.python.ops import mvn_tril as mvn_tril_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import variables -from tensorflow.python.ops.distributions import kullback_leibler as kullback_leibler_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.platform import test - -layers = layers_lib - - -class NormalNoEntropy(normal_lib.Normal): # pylint: disable=no-init - """Normal distribution without a `.entropy` method.""" - - def entropy(self): - return NotImplementedError('Entropy removed by gremlins') - - -def get_train_op(scalar_loss, optimizer='SGD', learning_rate=1.0, decay=0.0): - global_step = variables.Variable(0) - - def decay_fn(rate, t): - return rate * (1 + math_ops.to_float(t))**(-decay) - - train_op = layers.optimize_loss( - scalar_loss, - global_step, - learning_rate, - optimizer, - learning_rate_decay_fn=decay_fn) - return train_op - - -def _assert_monotonic_decreasing(array, atol=1e-5): - array = np.asarray(array) - _assert_monotonic_increasing(-array, atol=atol) - - -def _assert_monotonic_increasing(array, atol=1e-5): - array = np.asarray(array) - diff = np.diff(array.ravel()) - np.testing.assert_array_less(-1 * atol, diff) - - -class ElboRatioTest(test.TestCase): - """Show sampling converges to true KL values.""" - - def setUp(self): - self._rng = np.random.RandomState(0) - - def test_convergence_to_kl_using_sample_form_on_3dim_normal(self): - # Test that the sample mean KL is the same as analytic when we use samples - # to estimate every part of the KL divergence ratio. - vector_shape = (2, 3) - n_samples = 5000 - - with self.test_session(): - q = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - p = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - - # In this case, the log_ratio is the KL. - sample_kl = -1 * entropy.elbo_ratio( - log_p=p.log_prob, - q=q, - n=n_samples, - form=entropy.ELBOForms.sample, - seed=42) - actual_kl = kullback_leibler_lib.kl_divergence(q, p) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertEqual((2,), sample_kl.get_shape()) - self.assertAllClose(actual_kl.eval(), sample_kl.eval(), rtol=0.05) - - def test_convergence_to_kl_using_analytic_entropy_form_on_3dim_normal(self): - # Test that the sample mean KL is the same as analytic when we use an - # analytic entropy combined with sampled cross-entropy. - n_samples = 5000 - - vector_shape = (2, 3) - with self.test_session(): - q = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - p = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - - # In this case, the log_ratio is the KL. - sample_kl = -1 * entropy.elbo_ratio( - log_p=p.log_prob, - q=q, - n=n_samples, - form=entropy.ELBOForms.analytic_entropy, - seed=42) - actual_kl = kullback_leibler_lib.kl_divergence(q, p) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertEqual((2,), sample_kl.get_shape()) - self.assertAllClose(actual_kl.eval(), sample_kl.eval(), rtol=0.1) - - def test_sample_kl_zero_when_p_and_q_are_the_same_distribution(self): - n_samples = 50 - - vector_shape = (2, 3) - with self.test_session(): - q = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - - # In this case, the log_ratio is the KL. - sample_kl = -1 * entropy.elbo_ratio( - log_p=q.log_prob, - q=q, - n=n_samples, - form=entropy.ELBOForms.sample, - seed=42) - - self.assertEqual((2,), sample_kl.get_shape()) - self.assertAllClose(np.zeros(2), sample_kl.eval()) - - -class EntropyShannonTest(test.TestCase): - - def test_normal_entropy_default_form_uses_exact_entropy(self): - with self.test_session(): - dist = normal_lib.Normal(loc=1.11, scale=2.22) - mc_entropy = entropy.entropy_shannon(dist, n=11) - exact_entropy = dist.entropy() - self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) - self.assertAllClose(exact_entropy.eval(), mc_entropy.eval()) - - def test_normal_entropy_analytic_form_uses_exact_entropy(self): - with self.test_session(): - dist = normal_lib.Normal(loc=1.11, scale=2.22) - mc_entropy = entropy.entropy_shannon( - dist, form=entropy.ELBOForms.analytic_entropy) - exact_entropy = dist.entropy() - self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) - self.assertAllClose(exact_entropy.eval(), mc_entropy.eval()) - - def test_normal_entropy_sample_form_gets_approximate_answer(self): - # Tested by showing we get a good answer that is not exact. - with self.test_session(): - dist = normal_lib.Normal(loc=1.11, scale=2.22) - mc_entropy = entropy.entropy_shannon( - dist, n=1000, form=entropy.ELBOForms.sample, seed=0) - exact_entropy = dist.entropy() - - self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertAllClose(exact_entropy.eval(), mc_entropy.eval(), rtol=0.01) - - # Make sure there is some error, proving we used samples - self.assertLess(0.0001, math_ops.abs(exact_entropy - mc_entropy).eval()) - - def test_default_entropy_falls_back_on_sample_if_analytic_not_available(self): - # Tested by showing we get a good answer that is not exact. - with self.test_session(): - # NormalNoEntropy is like a Normal, but does not have .entropy method, so - # we are forced to fall back on sample entropy. - dist_no_entropy = NormalNoEntropy(loc=1.11, scale=2.22) - dist_yes_entropy = normal_lib.Normal(loc=1.11, scale=2.22) - - mc_entropy = entropy.entropy_shannon( - dist_no_entropy, n=1000, form=entropy.ELBOForms.sample, seed=0) - exact_entropy = dist_yes_entropy.entropy() - - self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertAllClose(exact_entropy.eval(), mc_entropy.eval(), rtol=0.01) - - # Make sure there is some error, proving we used samples - self.assertLess(0.0001, math_ops.abs(exact_entropy - mc_entropy).eval()) - - -class RenyiRatioTest(test.TestCase): - """Show renyi_ratio is minimized when the distributions match.""" - - def setUp(self): - self._rng = np.random.RandomState(0) - - def test_fitting_two_dimensional_normal_n_equals_1000(self): - # Minmizing Renyi divergence should allow us to make one normal match - # another one exactly. - n = 1000 - mu_true = np.array([1.0, -1.0], dtype=np.float64) - chol_true = np.array([[2.0, 0.0], [0.5, 1.0]], dtype=np.float64) - with self.test_session() as sess: - target = mvn_tril_lib.MultivariateNormalTriL(mu_true, chol_true) - - # Set up q distribution by defining mean/covariance as Variables - mu = variables.Variable( - np.zeros(mu_true.shape), dtype=mu_true.dtype, name='mu') - mat = variables.Variable( - np.zeros(chol_true.shape), dtype=chol_true.dtype, name='mat') - chol = distribution_util.matrix_diag_transform( - mat, transform=nn_ops.softplus) - q = mvn_tril_lib.MultivariateNormalTriL(mu, chol) - for alpha in [0.25, 0.75]: - - negative_renyi_divergence = entropy.renyi_ratio( - log_p=target.log_prob, q=q, n=n, alpha=alpha, seed=0) - train_op = get_train_op( - math_ops.reduce_mean(-negative_renyi_divergence), - optimizer='SGD', - learning_rate=0.5, - decay=0.1) - - variables.global_variables_initializer().run() - renyis = [] - for step in range(1000): - sess.run(train_op) - if step in [1, 5, 100]: - renyis.append(negative_renyi_divergence.eval()) - - # This optimization should maximize the renyi divergence. - _assert_monotonic_increasing(renyis, atol=0) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertAllClose(target.loc.eval(), q.loc.eval(), rtol=0.06) - self.assertAllClose(target.scale.to_dense().eval(), - q.scale.to_dense().eval(), - rtol=0.1) - - def test_divergence_between_identical_distributions_is_zero(self): - n = 1000 - vector_shape = (2, 3) - with self.test_session(): - q = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - for alpha in [0.25, 0.75]: - - negative_renyi_divergence = entropy.renyi_ratio( - log_p=q.log_prob, q=q, n=n, alpha=alpha, seed=0) - - self.assertEqual((2,), negative_renyi_divergence.get_shape()) - self.assertAllClose(np.zeros(2), negative_renyi_divergence.eval()) - - -class RenyiAlphaTest(test.TestCase): - - def test_with_three_alphas(self): - with self.test_session(): - for dtype in (dtypes.float32, dtypes.float64): - alpha_min = constant_op.constant(0.0, dtype=dtype) - alpha_max = 0.5 - decay_time = 3 - - alpha_0 = entropy.renyi_alpha( - 0, decay_time, alpha_min=alpha_min, alpha_max=alpha_max) - alpha_1 = entropy.renyi_alpha( - 1, decay_time, alpha_min=alpha_min, alpha_max=alpha_max) - alpha_2 = entropy.renyi_alpha( - 2, decay_time, alpha_min=alpha_min, alpha_max=alpha_max) - alpha_3 = entropy.renyi_alpha( - 3, decay_time, alpha_min=alpha_min, alpha_max=alpha_max) - - # Alpha should start at alpha_max. - self.assertAllClose(alpha_max, alpha_0.eval(), atol=1e-5) - # Alpha should finish at alpha_min. - self.assertAllClose(alpha_min.eval(), alpha_3.eval(), atol=1e-5) - # In between, alpha should be monotonically decreasing. - _assert_monotonic_decreasing( - [alpha_0.eval(), alpha_1.eval(), alpha_2.eval(), alpha_3.eval()]) - - def test_non_scalar_input_raises(self): - with self.test_session(): - # Good values here - step = 0 - alpha_min = 0.0 - alpha_max = 0.5 - decay_time = 3 - - # Use one bad value inside each check. - # The "bad" value is always the non-scalar one. - with self.assertRaisesRegexp(ValueError, 'must be scalar'): - entropy.renyi_alpha( - [step], decay_time, alpha_min=alpha_min, alpha_max=alpha_max).eval() - - with self.assertRaisesRegexp(ValueError, 'must be scalar'): - entropy.renyi_alpha( - step, [decay_time], alpha_min=alpha_min, alpha_max=alpha_max).eval() - - with self.assertRaisesRegexp(ValueError, 'must be scalar'): - entropy.renyi_alpha( - step, decay_time, alpha_min=[alpha_min], alpha_max=alpha_max).eval() - - with self.assertRaisesRegexp(ValueError, 'must be scalar'): - entropy.renyi_alpha( - step, decay_time, alpha_min=alpha_min, alpha_max=[alpha_max]).eval() - - def test_input_with_wrong_sign_raises(self): - with self.test_session(): - # Good values here - step = 0 - alpha_min = 0.0 - alpha_max = 0.5 - decay_time = 3 - - # Use one bad value inside each check. - # The "bad" value is always the non-scalar one. - with self.assertRaisesOpError('decay_time must be positive'): - entropy.renyi_alpha( - step, 0.0, alpha_min=alpha_min, alpha_max=alpha_max).eval() - - with self.assertRaisesOpError('step must be non-negative'): - entropy.renyi_alpha( - -1, decay_time, alpha_min=alpha_min, alpha_max=alpha_max).eval() - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a85862abfd744a86b9a38e10dbb5b985d0a0e94 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/halton_sequence_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for halton_sequence.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import halton_sequence as halton +from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +mc = monte_carlo_lib + + +class HaltonSequenceTest(test.TestCase): + + def test_known_values_small_bases(self): + with self.test_session(): + # The first five elements of the Halton sequence with base 2 and 3 + expected = np.array(((1. / 2, 1. / 3), + (1. / 4, 2. / 3), + (3. / 4, 1. / 9), + (1. / 8, 4. / 9), + (5. / 8, 7. / 9)), dtype=np.float32) + sample = halton.sample(2, num_samples=5) + self.assertAllClose(expected, sample.eval(), rtol=1e-6) + + def test_sample_indices(self): + with self.test_session(): + dim = 5 + indices = math_ops.range(10, dtype=dtypes.int32) + sample_direct = halton.sample(dim, num_samples=10) + sample_from_indices = halton.sample(dim, sample_indices=indices) + self.assertAllClose(sample_direct.eval(), sample_from_indices.eval(), + rtol=1e-6) + + def test_dtypes_works_correctly(self): + with self.test_session(): + dim = 3 + sample_float32 = halton.sample(dim, num_samples=10, dtype=dtypes.float32) + sample_float64 = halton.sample(dim, num_samples=10, dtype=dtypes.float64) + self.assertEqual(sample_float32.eval().dtype, np.float32) + self.assertEqual(sample_float64.eval().dtype, np.float64) + + def test_normal_integral_mean_and_var_correctly_estimated(self): + n = int(1000) + # This test is almost identical to the similarly named test in + # monte_carlo_test.py. The only difference is that we use the Halton + # samples instead of the random samples to evaluate the expectations. + # MC with pseudo random numbers converges at the rate of 1/ Sqrt(N) + # (N=number of samples). For QMC in low dimensions, the expected convergence + # rate is ~ 1/N. Hence we should only need 1e3 samples as compared to the + # 1e6 samples used in the pseudo-random monte carlo. + with self.test_session(): + mu_p = array_ops.constant([-1.0, 1.0], dtype=dtypes.float64) + mu_q = array_ops.constant([0.0, 0.0], dtype=dtypes.float64) + sigma_p = array_ops.constant([0.5, 0.5], dtype=dtypes.float64) + sigma_q = array_ops.constant([1.0, 1.0], dtype=dtypes.float64) + p = normal_lib.Normal(loc=mu_p, scale=sigma_p) + q = normal_lib.Normal(loc=mu_q, scale=sigma_q) + + cdf_sample = halton.sample(2, num_samples=n, dtype=dtypes.float64) + q_sample = q.quantile(cdf_sample) + + # Compute E_p[X]. + e_x = mc.expectation_importance_sampler( + f=lambda x: x, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, + seed=42) + + # Compute E_p[X^2]. + e_x2 = mc.expectation_importance_sampler( + f=math_ops.square, log_p=p.log_prob, sampling_dist_q=q, z=q_sample, + seed=42) + + stddev = math_ops.sqrt(e_x2 - math_ops.square(e_x)) + # Keep the tolerance levels the same as in monte_carlo_test.py. + self.assertEqual(p.batch_shape, e_x.get_shape()) + self.assertAllClose(p.mean().eval(), e_x.eval(), rtol=0.01) + self.assertAllClose(p.stddev().eval(), stddev.eval(), rtol=0.02) + + def test_docstring_example(self): + # Produce the first 1000 members of the Halton sequence in 3 dimensions. + num_samples = 1000 + dim = 3 + with self.test_session(): + sample = halton.sample(dim, num_samples=num_samples) + + # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional + # hypercube. + powers = math_ops.range(1.0, limit=dim + 1) + integral = math_ops.reduce_mean( + math_ops.reduce_prod(sample ** powers, axis=-1)) + true_value = 1.0 / math_ops.reduce_prod(powers + 1.0) + + # Produces a relative absolute error of 1.7%. + self.assertAllClose(integral.eval(), true_value.eval(), rtol=0.02) + + # Now skip the first 1000 samples and recompute the integral with the next + # thousand samples. The sample_indices argument can be used to do this. + + sample_indices = math_ops.range(start=1000, limit=1000 + num_samples, + dtype=dtypes.int32) + sample_leaped = halton.sample(dim, sample_indices=sample_indices) + + integral_leaped = math_ops.reduce_mean( + math_ops.reduce_prod(sample_leaped ** powers, axis=-1)) + self.assertAllClose(integral_leaped.eval(), true_value.eval(), rtol=0.001) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index b1f108e5f01e4945ee83d8262f1d99877f0fe9f0..cbc66b6dc13db62c25952de6b6c13b2fdfe27f12 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Hamiltonian Monte Carlo. -""" +"""Tests for Hamiltonian Monte Carlo.""" from __future__ import absolute_import from __future__ import division @@ -27,6 +26,7 @@ from tensorflow.contrib.bayesflow.python.ops import hmc from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -46,6 +46,9 @@ class HMCTest(test.TestCase): random_seed.set_random_seed(10003) np.random.seed(10003) + def assertAllFinite(self, x): + self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) + def _log_gamma_log_prob(self, x, event_dims=()): """Computes log-pdf of a log-gamma random variable. @@ -345,5 +348,97 @@ class HMCTest(test.TestCase): def testAIS12(self): self._ais_gets_correct_log_normalizer_wrapper([1, 2]) + def testNanRejection(self): + """Tests that an update that yields NaN potentials gets rejected. + + We run HMC with a target distribution that returns NaN + log-likelihoods if any element of x < 0, and unit-scale + exponential log-likelihoods otherwise. The exponential potential + pushes x towards 0, ensuring that any reasonably large update will + push us over the edge into NaN territory. + """ + def _unbounded_exponential_log_prob(x): + """An exponential distribution with log-likelihood NaN for x < 0.""" + per_element_potentials = array_ops.where(x < 0, + np.nan * array_ops.ones_like(x), + -x) + return math_ops.reduce_sum(per_element_potentials) + + with self.test_session() as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, acceptance_probs, _, _ = hmc.kernel( + 2., 5, initial_x, _unbounded_exponential_log_prob, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) + + def testNanFromGradsDontPropagate(self): + """Test that update with NaN gradients does not cause NaN in results.""" + def _nan_log_prob_with_nan_gradient(x): + return np.nan * math_ops.reduce_sum(x) + + with self.test_session() as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, acceptance_probs, new_log_prob, new_grad = hmc.kernel( + 2., 5, initial_x, _nan_log_prob_with_nan_gradient, [0]) + initial_x_val, updated_x_val, acceptance_probs_val = sess.run( + [initial_x, updated_x, acceptance_probs]) + + logging.vlog(1, 'initial_x = {}'.format(initial_x_val)) + logging.vlog(1, 'updated_x = {}'.format(updated_x_val)) + logging.vlog(1, 'acceptance_probs = {}'.format(acceptance_probs_val)) + + self.assertAllEqual(initial_x_val, updated_x_val) + self.assertEqual(acceptance_probs_val, 0.) + + self.assertAllFinite( + gradients_impl.gradients(updated_x, initial_x)[0].eval()) + self.assertTrue( + gradients_impl.gradients(new_grad, initial_x)[0] is None) + + # Gradients of the acceptance probs and new log prob are not finite. + _ = new_log_prob # Prevent unused arg error. + # self.assertAllFinite( + # gradients_impl.gradients(acceptance_probs, initial_x)[0].eval()) + # self.assertAllFinite( + # gradients_impl.gradients(new_log_prob, initial_x)[0].eval()) + + def testChainWorksIn64Bit(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float64(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float64), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float64, states_.dtype) + self.assertEqual(np.float64, acceptance_probs_.dtype) + + def testChainWorksIn16Bit(self): + def log_prob(x): + return - math_ops.reduce_sum(x * x, axis=-1) + states, acceptance_probs = hmc.chain( + n_iterations=10, + step_size=np.float16(0.01), + n_leapfrog_steps=10, + initial_x=np.zeros(5).astype(np.float16), + target_log_prob_fn=log_prob, + event_dims=[-1]) + with self.test_session() as sess: + states_, acceptance_probs_ = sess.run([states, acceptance_probs]) + self.assertEqual(np.float16, states_.dtype) + self.assertEqual(np.float16, acceptance_probs_.dtype) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py new file mode 100644 index 0000000000000000000000000000000000000000..50358fd1c2b7635ffe2d08c5af3219bb0a11498b --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py @@ -0,0 +1,304 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dense Bayesian layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +class Counter(object): + """Helper class to manage incrementing a counting `int`.""" + + def __init__(self): + self._value = -1 + + @property + def value(self): + return self._value + + def __call__(self): + self._value += 1 + return self._value + + +class MockDistribution(normal_lib.Normal): + """Monitors DenseVariational calls to the underlying distribution.""" + + def __init__(self, result_sample, result_log_prob, loc=None, scale=None): + self.result_sample = result_sample + self.result_log_prob = result_log_prob + self.result_loc = loc + self.result_scale = scale + self.called_log_prob = Counter() + self.called_sample = Counter() + self.called_loc = Counter() + self.called_scale = Counter() + + def log_prob(self, *args, **kwargs): + self.called_log_prob() + return self.result_log_prob + + def sample(self, *args, **kwargs): + self.called_sample() + return self.result_sample + + @property + def loc(self): + self.called_loc() + return self.result_loc + + @property + def scale(self): + self.called_scale() + return self.result_scale + + +class MockKLDivergence(object): + """Monitors DenseVariational calls to the divergence implementation.""" + + def __init__(self, result): + self.result = result + self.args = [] + self.called = Counter() + + def __call__(self, *args, **kwargs): + self.called() + self.args.append(args) + return self.result + + +class DenseVariationalLocalReparametrization(test.TestCase): + + def testKLPenaltyKernel(self): + with self.test_session(): + dense_vi = prob_layers_lib.DenseVariational(units=2) + inputs = random_ops.random_uniform([2, 3], seed=1) + + # No keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 0) + self.assertListEqual(dense_vi.losses, loss_keys) + + _ = dense_vi(inputs) + + # Yes keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 1) + self.assertListEqual(dense_vi.losses, loss_keys) + + def testKLPenaltyBoth(self): + def _make_normal(dtype, *args): # pylint: disable=unused-argument + return normal_lib.Normal( + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) + with self.test_session(): + dense_vi = prob_layers_lib.DenseVariational( + units=2, + bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(), + bias_prior_fn=_make_normal) + inputs = random_ops.random_uniform([2, 3], seed=1) + + # No keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 0) + self.assertListEqual(dense_vi.losses, loss_keys) + + _ = dense_vi(inputs) + + # Yes keys. + loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(loss_keys), 2) + self.assertListEqual(dense_vi.losses, loss_keys) + + def testVariationalNonLocal(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + expected_outputs = ( + math_ops.matmul(inputs, kernel_posterior.result_sample) + + bias_posterior.result_sample) + + dense_vi = prob_layers_lib.DenseVariational( + units=2, + kernel_use_local_reparameterization=False, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence) + + outputs = dense_vi(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + + [ + expected_outputs_, actual_outputs_, + expected_kernel_, actual_kernel_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_kernel_, actual_kernel_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior, bias_prior, bias_posterior.result_sample]], + bias_divergence.args) + + def testVariationalLocal(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform(kernel_size, seed=seed()), + scale=random_ops.random_uniform(kernel_size, seed=seed()), + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + expected_kernel_posterior_affine = normal_lib.Normal( + loc=math_ops.matmul(inputs, kernel_posterior.result_loc), + scale=math_ops.matmul( + inputs**2., kernel_posterior.result_scale**2)**0.5) + expected_kernel_posterior_affine_tensor = ( + expected_kernel_posterior_affine.sample(seed=42)) + expected_outputs = (expected_kernel_posterior_affine_tensor + + bias_posterior.result_sample) + + dense_vi = prob_layers_lib.DenseVariational( + units=2, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence) + + outputs = dense_vi(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + + [ + expected_outputs_, actual_outputs_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior, kernel_prior, None]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior, bias_prior, bias_posterior.result_sample]], + bias_divergence.args) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..66793383fdd5c71f136900197a91be6966e2f8c7 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py @@ -0,0 +1,209 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional test for GradientDescent.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import math +from tensorflow.contrib.bayesflow.python.ops.optimizers import SGLDOptimizer +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class SGLDOptimizerTest(test.TestCase): + + def testBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.53 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + + def testBasicMultiInstance(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + vara = variables.Variable([1.1, 2.1], dtype=dtype) + varb = variables.Variable([3.0, 4.0], dtype=dtype) + gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) + gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.5 + sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) + sgd_op = sgd_optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1])) + sgd_optimizer2 = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate) + sgd_op2 = sgd_optimizer2.apply_gradients( + zip([gradsa, gradsb], [vara, varb])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) + + # Run 1 step of sgd + sgd_op.run() + sgd_op2.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], vara.eval()) + + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], varb.eval()) + self.assertNotEqual(sgd_optimizer.variable_scope, + sgd_optimizer2.variable_scope) + self.assertNotEqual(sgd_optimizer.variable_scope.name, + sgd_optimizer2.variable_scope.name) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + lrate = constant_op.constant(3.0) + decay_rate = 0.5 + sgd_op = SGLDOptimizer( + lrate, preconditioner_decay_rate=constant_op.constant( + decay_rate)).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + + def testGradWrtRef(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + opt = SGLDOptimizer(3.0) + values = [1.0, 3.0] + vars_ = [variables.Variable([v], dtype=dtype) for v in values] + grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) + variables.global_variables_initializer().run() + for grad, _ in grads_and_vars: + self.assertAllCloseAccordingToType([1.0], grad.eval()) + + def testWithGlobalStep(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + global_step = variables.Variable(0, trainable=False) + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.1 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + + # Validate updated params and global_step + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [1.1 - 3.0 * grads_scaled, 2.1 - 3.0 * grads_scaled], var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + self.assertAllCloseAccordingToType(1, global_step.eval()) + + def testSparseBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) + var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) + grads0 = ops.IndexedSlices( + constant_op.constant([0.1], shape=[1, 1], dtype=dtype), + constant_op.constant([0]), constant_op.constant([2, 1])) + grads1 = ops.IndexedSlices( + constant_op.constant([0.01], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), constant_op.constant([2, 1])) + decay_rate = 0.9 + sgd_op = SGLDOptimizer( + 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) + self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + grads_scaled = (0.5 * 0.1 / math.sqrt(decay_rate + + (1 - decay_rate) * 0.1**2 + 1e-8)) + self.assertAllCloseAccordingToType([[1.1 - 3.0 * grads_scaled], [2.1]], + var0.eval()) + grads_scaled = (0.5 * 0.01 / math.sqrt( + decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) + self.assertAllCloseAccordingToType( + [[3.0 - 3.0 * 0], [4.0 - 3.0 * grads_scaled]], var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py deleted file mode 100644 index 9b1f482b34967082d6ac44494123879fb8fb0ee3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for stochastic graphs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.contrib import distributions -from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -st = stochastic_tensor -sge = stochastic_gradient_estimators -dists = distributions - - -def _vimco(loss): - """Python implementation of VIMCO.""" - n = loss.shape[0] - log_loss = np.log(loss) - geometric_mean = [] - for j in range(n): - geometric_mean.append( - np.exp(np.mean([log_loss[i, :] for i in range(n) if i != j], 0))) - geometric_mean = np.array(geometric_mean) - - learning_signal = [] - for j in range(n): - learning_signal.append(np.sum([loss[i, :] for i in range(n) if i != j], 0)) - learning_signal = np.array(learning_signal) - - local_learning_signal = np.log(1 / n * (learning_signal + geometric_mean)) - - # log_mean - local_learning_signal - log_mean = np.log(np.mean(loss, 0)) - advantage = log_mean - local_learning_signal - - return advantage - - -class StochasticGradientEstimatorsTest(test.TestCase): - - def setUp(self): - self._p = constant_op.constant(0.999999) - self._final_loss = constant_op.constant(3.2) - - def _testScoreFunction(self, loss_fn, expected): - x = st.StochasticTensor(dists.Bernoulli(probs=self._p), loss_fn=loss_fn) - sf = x.loss(self._final_loss) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllClose(*sess.run([expected, sf])) - - def testScoreFunction(self): - expected = math_ops.log(self._p) * self._final_loss - self._testScoreFunction(sge.score_function, expected) - - def testScoreFunctionWithConstantBaseline(self): - b = constant_op.constant(9.8) - expected = math_ops.log(self._p) * (self._final_loss - b) - self._testScoreFunction( - sge.get_score_function_with_constant_baseline(b), expected) - - def testScoreFunctionWithBaselineFn(self): - b = constant_op.constant(9.8) - - def baseline_fn(stoch_tensor, loss): - self.assertTrue(isinstance(stoch_tensor, st.StochasticTensor)) - self.assertTrue(isinstance(loss, ops.Tensor)) - return b - - expected = math_ops.log(self._p) * (self._final_loss - b) - self._testScoreFunction( - sge.get_score_function_with_baseline(baseline_fn), expected) - - def testScoreFunctionWithMeanBaseline(self): - ema_decay = 0.8 - num_steps = 6 - x = st.StochasticTensor( - dists.Bernoulli(probs=self._p), - loss_fn=sge.get_score_function_with_baseline( - sge.get_mean_baseline(ema_decay))) - sf = x.loss(self._final_loss) - - # Expected EMA value - ema = 0. - for _ in range(num_steps): - ema -= (1. - ema_decay) * (ema - self._final_loss) - - # Baseline is EMA with bias correction - bias_correction = 1. - ema_decay**num_steps - baseline = ema / bias_correction - expected = math_ops.log(self._p) * (self._final_loss - baseline) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - for _ in range(num_steps - 1): - sess.run(sf) # run to update EMA - self.assertAllClose(*sess.run([expected, sf])) - - def testScoreFunctionWithAdvantageFn(self): - b = constant_op.constant(9.8) - - def advantage_fn(stoch_tensor, loss): - self.assertTrue(isinstance(stoch_tensor, st.StochasticTensor)) - self.assertTrue(isinstance(loss, ops.Tensor)) - return loss - b - - expected = math_ops.log(self._p) * (self._final_loss - b) - self._testScoreFunction( - sge.get_score_function_with_advantage(advantage_fn), expected) - - def testVIMCOAdvantageFn(self): - # simple_loss: (3, 2) with 3 samples, batch size 2 - simple_loss = np.array( - [[1.0, 1.5], - [1e-6, 1e4], - [2.0, 3.0]]) - # random_loss: (100, 50, 64) with 100 samples, batch shape (50, 64) - random_loss = 100 * np.random.rand(100, 50, 64) - - advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=False) - - with self.test_session() as sess: - for loss in [simple_loss, random_loss]: - expected = _vimco(loss) - loss_t = constant_op.constant(loss, dtype=dtypes.float32) - advantage_t = advantage_fn(None, loss_t) # ST is not used - advantage = sess.run(advantage_t) - self.assertEqual(expected.shape, advantage_t.get_shape()) - self.assertAllClose(expected, advantage, atol=5e-5) - - def testVIMCOAdvantageGradients(self): - loss = np.log( - [[1.0, 1.5], - [1e-6, 1e4], - [2.0, 3.0]]) - advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True) - - with self.test_session(): - loss_t = constant_op.constant(loss, dtype=dtypes.float64) - advantage_t = advantage_fn(None, loss_t) # ST is not used - gradient_error = gradient_checker.compute_gradient_error( - loss_t, - loss_t.get_shape().as_list(), - advantage_t, - advantage_t.get_shape().as_list(), - x_init_value=loss) - self.assertLess(gradient_error, 1e-3) - - def testVIMCOAdvantageWithSmallProbabilities(self): - theta_value = np.random.rand(10, 100000) - # Test with float16 dtype to ensure stability even in this extreme case. - theta = constant_op.constant(theta_value, dtype=dtypes.float16) - advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True) - - with self.test_session() as sess: - log_loss = -math_ops.reduce_sum(theta, [1]) - advantage_t = advantage_fn(None, log_loss) - grad_t = gradients_impl.gradients(advantage_t, theta)[0] - advantage, grad = sess.run((advantage_t, grad_t)) - self.assertTrue(np.all(np.isfinite(advantage))) - self.assertTrue(np.all(np.isfinite(grad))) - - def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self): - ema_decay = 0.8 - x = st.StochasticTensor( - dists.Bernoulli(probs=self._p), - loss_fn=sge.get_score_function_with_baseline( - sge.get_mean_baseline(ema_decay))) - y = st.StochasticTensor( - dists.Bernoulli(probs=self._p), - loss_fn=sge.get_score_function_with_baseline( - sge.get_mean_baseline(ema_decay))) - sf_x = x.loss(self._final_loss) - sf_y = y.loss(self._final_loss) - with self.test_session() as sess: - # Smoke test - sess.run(variables.global_variables_initializer()) - sess.run([sf_x, sf_y]) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py deleted file mode 100644 index 44e27db03b18d0e6a789db676bea684c10dcfca7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for stochastic graphs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import distributions as distributions_lib -from tensorflow.contrib.bayesflow.python.ops import stochastic_graph_impl -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import test - -st = stochastic_tensor -sg = stochastic_graph_impl -distributions = distributions_lib - - -class NormalNotParam(distributions.Normal): - - @property - def reparameterization_type(self): - return distributions.NOT_REPARAMETERIZED - - -class TestSurrogateLosses(test.TestCase): - - def testPathwiseDerivativeDoesNotAddSurrogateLosses(self): - with self.test_session(): - mu = [0.0, 0.1, 0.2] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) - likelihood = st.StochasticTensor( - distributions.Normal( - loc=prior, scale=sigma)) - self.assertEqual( - prior.distribution.reparameterization_type, - distributions.FULLY_REPARAMETERIZED) - self.assertEqual( - likelihood.distribution.reparameterization_type, - distributions.FULLY_REPARAMETERIZED) - - loss = math_ops.square(array_ops.identity(likelihood) - [0.0, 0.1, 0.2]) - sum_loss = math_ops.reduce_sum(loss) - - surrogate_loss = sg.surrogate_loss([loss]) - with self.assertRaisesRegexp(ValueError, "dimensionality 1 or greater"): - _ = sg.surrogate_loss([sum_loss]) - surrogate_from_both = sg.surrogate_loss( - [loss, sum_loss * array_ops.ones_like(loss)]) - - # Pathwise derivative terms do not require add'l surrogate loss terms. - with self.test_session() as sess: - self.assertAllClose(*sess.run([loss, surrogate_loss])) - self.assertAllClose(*sess.run([(loss + sum_loss), surrogate_from_both])) - - def _testSurrogateLoss(self, session, losses, expected_addl_terms, xs): - surrogate_loss = sg.surrogate_loss(losses) - expected_surrogate_loss = math_ops.add_n(losses + expected_addl_terms) - self.assertAllClose(*session.run([surrogate_loss, expected_surrogate_loss])) - - # Test backprop - expected_grads = gradients_impl.gradients(ys=expected_surrogate_loss, xs=xs) - surrogate_grads = gradients_impl.gradients(ys=surrogate_loss, xs=xs) - self.assertEqual(len(expected_grads), len(surrogate_grads)) - grad_values = session.run(expected_grads + surrogate_grads) - n_grad = len(expected_grads) - self.assertAllClose(grad_values[:n_grad], grad_values[n_grad:]) - - def testSurrogateLoss(self): - with self.test_session() as sess: - mu = constant_op.constant([0.0, 0.1, 0.2]) - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) - likelihood = st.StochasticTensor(NormalNotParam(loc=prior, scale=sigma)) - prior_2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) - - loss = math_ops.square(array_ops.identity(likelihood) - mu) - part_loss = math_ops.square(array_ops.identity(prior) - mu) - sum_loss = math_ops.reduce_sum(loss) - loss_nodeps = math_ops.square(array_ops.identity(prior_2) - mu) - - # For ground truth, use the stop-gradient versions of the losses - loss_nograd = array_ops.stop_gradient(loss) - loss_nodeps_nograd = array_ops.stop_gradient(loss_nodeps) - sum_loss_nograd = array_ops.stop_gradient(sum_loss) - - # These score functions should ignore prior_2 - self._testSurrogateLoss( - session=sess, - losses=[loss], - expected_addl_terms=[ - likelihood.distribution.log_prob( - likelihood.value()) * loss_nograd, - prior.distribution.log_prob(prior.value()) * loss_nograd - ], - xs=[mu, sigma]) - - self._testSurrogateLoss( - session=sess, - losses=[loss, part_loss], - expected_addl_terms=[ - likelihood.distribution.log_prob( - likelihood.value()) * loss_nograd, - (prior.distribution.log_prob(prior.value()) * - array_ops.stop_gradient(part_loss + loss)) - ], - xs=[mu, sigma]) - - self._testSurrogateLoss( - session=sess, - losses=[sum_loss * array_ops.ones_like(loss)], - expected_addl_terms=[( - likelihood.distribution.log_prob(likelihood.value()) * - sum_loss_nograd), prior.distribution.log_prob(prior.value()) * - sum_loss_nograd], - xs=[mu, sigma]) - - self._testSurrogateLoss( - session=sess, - losses=[loss, sum_loss * array_ops.ones_like(loss)], - expected_addl_terms=[( - likelihood.distribution.log_prob(likelihood.value()) * - array_ops.stop_gradient(loss + sum_loss)), - (prior.distribution.log_prob(prior.value()) * - array_ops.stop_gradient(loss + sum_loss))], - xs=[mu, sigma]) - - # These score functions should ignore prior and likelihood - self._testSurrogateLoss( - session=sess, - losses=[loss_nodeps], - expected_addl_terms=[(prior_2.distribution.log_prob(prior_2.value()) * - loss_nodeps_nograd)], - xs=[mu, sigma]) - - # These score functions should include all terms selectively - self._testSurrogateLoss( - session=sess, - losses=[loss, loss_nodeps], - # We can't guarantee ordering of output losses in this case. - expected_addl_terms=[( - likelihood.distribution.log_prob(likelihood.value()) * - loss_nograd), prior.distribution.log_prob(prior.value()) * - loss_nograd, - (prior_2.distribution.log_prob(prior_2.value()) * - loss_nodeps_nograd)], - xs=[mu, sigma]) - - def testNoSurrogateLoss(self): - with self.test_session(): - mu = constant_op.constant([0.0, 0.1, 0.2]) - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleValue()): - dt = st.StochasticTensor( - NormalNotParam( - loc=mu, scale=sigma), loss_fn=None) - self.assertEqual(None, dt.loss(constant_op.constant([2.0]))) - - def testExplicitStochasticTensors(self): - with self.test_session() as sess: - mu = constant_op.constant([0.0, 0.1, 0.2]) - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleValue()): - dt1 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) - dt2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) - loss = math_ops.square(array_ops.identity(dt1)) + 10. + dt2 - - sl_all = sg.surrogate_loss([loss]) - sl_dt1 = sg.surrogate_loss([loss], stochastic_tensors=[dt1]) - sl_dt2 = sg.surrogate_loss([loss], stochastic_tensors=[dt2]) - - dt1_term = dt1.distribution.log_prob(dt1) * loss - dt2_term = dt2.distribution.log_prob(dt2) * loss - - self.assertAllClose(*sess.run( - [sl_all, sum([loss, dt1_term, dt2_term])])) - self.assertAllClose(*sess.run([sl_dt1, sum([loss, dt1_term])])) - self.assertAllClose(*sess.run([sl_dt2, sum([loss, dt2_term])])) - - -class StochasticDependenciesMapTest(test.TestCase): - - def testBuildsMapOfUpstreamNodes(self): - dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - dt2 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - out1 = dt1.value() + 1. - out2 = dt2.value() + 2. - x = out1 + out2 - y = out2 * 3. - dep_map = sg._stochastic_dependencies_map([x, y]) - self.assertEqual(dep_map[dt1], set([x])) - self.assertEqual(dep_map[dt2], set([x, y])) - - def testHandlesStackedStochasticNodes(self): - dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - out1 = dt1.value() + 1. - dt2 = st.StochasticTensor(distributions.Normal(loc=out1, scale=1.)) - x = dt2.value() + 2. - dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - y = dt3.value() * 3. - dep_map = sg._stochastic_dependencies_map([x, y]) - self.assertEqual(dep_map[dt1], set([x])) - self.assertEqual(dep_map[dt2], set([x])) - self.assertEqual(dep_map[dt3], set([y])) - - def testTraversesControlInputs(self): - dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - logits = dt1.value() * 3. - dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits)) - dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - x = dt3.value() - y = array_ops.ones((2, 2)) * 4. - z = array_ops.ones((2, 2)) * 3. - out = control_flow_ops.cond( - math_ops.cast(dt2, dtypes.bool), lambda: math_ops.add(x, y), - lambda: math_ops.square(z)) - out += 5. - dep_map = sg._stochastic_dependencies_map([out]) - self.assertEqual(dep_map[dt1], set([out])) - self.assertEqual(dep_map[dt2], set([out])) - self.assertEqual(dep_map[dt3], set([out])) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py deleted file mode 100644 index 6d0cff4678972719cb5c565bc409041e298beadb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for stochastic graphs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions import normal -from tensorflow.python.platform import test - -sge = stochastic_gradient_estimators -st = stochastic_tensor_impl - - -class StochasticTensorTest(test.TestCase): - - def testConstructionAndValue(self): - with self.test_session() as sess: - mu = [0.0, 0.1, 0.2] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - sigma2 = constant_op.constant([0.1, 0.2, 0.3]) - - prior_default = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma)) - self.assertTrue(isinstance(prior_default.value_type, st.SampleValue)) - prior_0 = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma), - dist_value_type=st.SampleValue()) - self.assertTrue(isinstance(prior_0.value_type, st.SampleValue)) - - with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(normal.Normal(loc=mu, scale=sigma)) - self.assertTrue(isinstance(prior.value_type, st.SampleValue)) - likelihood = st.StochasticTensor( - normal.Normal(loc=prior, scale=sigma2)) - self.assertTrue(isinstance(likelihood.value_type, st.SampleValue)) - - coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) - self.assertEqual(coll, [prior_default, prior_0, prior, likelihood]) - - # Also works: tf.convert_to_tensor(prior) - prior_default = array_ops.identity(prior_default) - prior_0 = array_ops.identity(prior_0) - prior = array_ops.identity(prior) - likelihood = array_ops.identity(likelihood) - - # Mostly a smoke test for now... - prior_0_val, prior_val, prior_default_val, _ = sess.run( - [prior_0, prior, prior_default, likelihood]) - - self.assertEqual(prior_0_val.shape, prior_val.shape) - self.assertEqual(prior_default_val.shape, prior_val.shape) - # These are different random samples from the same distribution, - # so the values should differ. - self.assertGreater(np.abs(prior_0_val - prior_val).sum(), 1e-6) - self.assertGreater(np.abs(prior_default_val - prior_val).sum(), 1e-6) - - def testMeanValue(self): - with self.test_session() as sess: - mu = [0.0, -1.0, 1.0] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - - with st.value_type(st.MeanValue()): - prior = st.StochasticTensor(normal.Normal(loc=mu, scale=sigma)) - self.assertTrue(isinstance(prior.value_type, st.MeanValue)) - - prior_mean = prior.mean() - prior_value = prior.value() - - prior_mean_val, prior_value_val = sess.run([prior_mean, prior_value]) - self.assertAllEqual(prior_mean_val, mu) - self.assertAllEqual(prior_mean_val, prior_value_val) - - def testSampleValueScalar(self): - with self.test_session() as sess: - mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]] - sigma = constant_op.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]]) - - with st.value_type(st.SampleValue()): - prior_single = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma)) - - prior_single_value = prior_single.value() - self.assertEqual(prior_single_value.get_shape(), (2, 3)) - - prior_single_value_val = sess.run([prior_single_value])[0] - self.assertEqual(prior_single_value_val.shape, (2, 3)) - - with st.value_type(st.SampleValue(1)): - prior_single = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma)) - self.assertTrue(isinstance(prior_single.value_type, st.SampleValue)) - - prior_single_value = prior_single.value() - self.assertEqual(prior_single_value.get_shape(), (1, 2, 3)) - - prior_single_value_val = sess.run([prior_single_value])[0] - self.assertEqual(prior_single_value_val.shape, (1, 2, 3)) - - with st.value_type(st.SampleValue(2)): - prior_double = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma)) - - prior_double_value = prior_double.value() - self.assertEqual(prior_double_value.get_shape(), (2, 2, 3)) - - prior_double_value_val = sess.run([prior_double_value])[0] - self.assertEqual(prior_double_value_val.shape, (2, 2, 3)) - - def testDistributionEntropy(self): - with self.test_session() as sess: - mu = [0.0, -1.0, 1.0] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.MeanValue()): - prior = st.StochasticTensor(normal.Normal(loc=mu, scale=sigma)) - entropy = prior.entropy() - deep_entropy = prior.distribution.entropy() - expected_deep_entropy = normal.Normal( - loc=mu, scale=sigma).entropy() - entropies = sess.run([entropy, deep_entropy, expected_deep_entropy]) - self.assertAllEqual(entropies[2], entropies[0]) - self.assertAllEqual(entropies[1], entropies[0]) - - def testSurrogateLoss(self): - with self.test_session(): - mu = [[3.0, -4.0, 5.0], [6.0, -7.0, 8.0]] - sigma = constant_op.constant(1.0) - - # With default - with st.value_type(st.MeanValue(stop_gradient=True)): - dt = st.StochasticTensor(normal.Normal(loc=mu, scale=sigma)) - loss = dt.loss([constant_op.constant(2.0)]) - self.assertTrue(loss is not None) - self.assertAllClose( - dt.distribution.log_prob(mu).eval() * 2.0, loss.eval()) - - # With passed-in loss_fn. - dt = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma), - dist_value_type=st.MeanValue(stop_gradient=True), - loss_fn=sge.get_score_function_with_constant_baseline( - baseline=constant_op.constant(8.0))) - loss = dt.loss([constant_op.constant(2.0)]) - self.assertTrue(loss is not None) - self.assertAllClose((dt.distribution.log_prob(mu) * (2.0 - 8.0)).eval(), - loss.eval()) - - -class ValueTypeTest(test.TestCase): - - def testValueType(self): - type_mean = st.MeanValue() - type_reshape = st.SampleValue() - type_full = st.SampleValue() - with st.value_type(type_mean): - self.assertEqual(st.get_current_value_type(), type_mean) - with st.value_type(type_reshape): - self.assertEqual(st.get_current_value_type(), type_reshape) - with st.value_type(type_full): - self.assertEqual(st.get_current_value_type(), type_full) - self.assertEqual(st.get_current_value_type(), type_mean) - with self.assertRaisesRegexp(ValueError, "No value type currently set"): - st.get_current_value_type() - - -class ObservedStochasticTensorTest(test.TestCase): - - def testConstructionAndValue(self): - with self.test_session() as sess: - mu = [0.0, 0.1, 0.2] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - obs = array_ops.zeros((2, 3)) - z = st.ObservedStochasticTensor( - normal.Normal(loc=mu, scale=sigma), value=obs) - [obs_val, z_val] = sess.run([obs, z.value()]) - self.assertAllEqual(obs_val, z_val) - - coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) - self.assertEqual(coll, [z]) - - def testConstructionWithUnknownShapes(self): - mu = array_ops.placeholder(dtypes.float32) - sigma = array_ops.placeholder(dtypes.float32) - obs = array_ops.placeholder(dtypes.float32) - z = st.ObservedStochasticTensor( - normal.Normal(loc=mu, scale=sigma), value=obs) - - mu2 = array_ops.placeholder(dtypes.float32, shape=[None]) - sigma2 = array_ops.placeholder(dtypes.float32, shape=[None]) - obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None]) - z2 = st.ObservedStochasticTensor( - normal.Normal(loc=mu2, scale=sigma2), value=obs2) - - coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) - self.assertEqual(coll, [z, z2]) - - def testConstructionErrors(self): - mu = [0., 0.] - sigma = [1., 1.] - self.assertRaises( - ValueError, - st.ObservedStochasticTensor, - normal.Normal(loc=mu, scale=sigma), - value=array_ops.zeros((3,))) - self.assertRaises( - ValueError, - st.ObservedStochasticTensor, - normal.Normal(loc=mu, scale=sigma), - value=array_ops.zeros((3, 1))) - self.assertRaises( - ValueError, - st.ObservedStochasticTensor, - normal.Normal(loc=mu, scale=sigma), - value=array_ops.zeros((1, 2), dtype=dtypes.int32)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py deleted file mode 100644 index 9ee59a03ca76c6095e34b869d9b175e2c9223cd7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for stochastic graphs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.contrib import distributions -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.contrib.bayesflow.python.ops import stochastic_variables -from tensorflow.contrib.bayesflow.python.ops import variational_inference_impl -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -sv = stochastic_variables -st = stochastic_tensor -vi = variational_inference_impl -dist = distributions - - -class StochasticVariablesTest(test.TestCase): - - def testStochasticVariables(self): - shape = (10, 20) - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale)): - v = variable_scope.get_variable("sv", shape) - - self.assertTrue(isinstance(v, st.StochasticTensor)) - self.assertTrue(isinstance(v.distribution, dist.NormalWithSoftplusScale)) - - self.assertEqual( - {"stochastic_variables/sv_loc", "stochastic_variables/sv_scale"}, - set([v.op.name for v in variables.global_variables()])) - self.assertEqual( - set(variables.trainable_variables()), set(variables.global_variables())) - - v = ops.convert_to_tensor(v) - self.assertEqual(list(shape), v.get_shape().as_list()) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertEqual(shape, sess.run(v).shape) - - def testStochasticVariablesWithConstantInitializer(self): - shape = (10, 20) - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale, - dist_kwargs={"validate_args": True}, - param_initializers={ - "loc": np.ones(shape) * 4., - "scale": np.ones(shape) * 2. - })): - v = variable_scope.get_variable("sv") - - for var in variables.global_variables(): - if "loc" in var.name: - mu_var = var - if "scale" in var.name: - sigma_var = var - - v = ops.convert_to_tensor(v) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(np.ones(shape) * 4., sess.run(mu_var)) - self.assertAllEqual(np.ones(shape) * 2., sess.run(sigma_var)) - self.assertEqual(shape, sess.run(v).shape) - - def testStochasticVariablesWithCallableInitializer(self): - shape = (10, 20) - - def sigma_init(shape, dtype, partition_info): - _ = partition_info - return array_ops.ones(shape, dtype=dtype) * 2. - - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale, - dist_kwargs={"validate_args": True}, - param_initializers={ - "loc": np.ones( - shape, dtype=np.float32) * 4., - "scale": sigma_init - })): - v = variable_scope.get_variable("sv", shape) - - for var in variables.global_variables(): - if "loc" in var.name: - mu_var = var - if "scale" in var.name: - sigma_var = var - - v = ops.convert_to_tensor(v) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(np.ones(shape) * 4., sess.run(mu_var)) - self.assertAllEqual(np.ones(shape) * 2., sess.run(sigma_var)) - self.assertEqual(shape, sess.run(v).shape) - - def testStochasticVariablesWithPrior(self): - shape = (10, 20) - prior = dist.Normal(0., 1.) - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale, prior=prior)): - w = variable_scope.get_variable("weights", shape) - - x = random_ops.random_uniform((8, 10)) - y = math_ops.matmul(x, w) - - prior_map = vi._find_variational_and_priors(y, None) - self.assertEqual(prior_map[w], prior) - elbo = vi.elbo(y, keep_batch_dim=False) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(elbo) - - def testStochasticVariablesWithCallablePriorInitializer(self): - - def prior_init(shape, dtype): - return dist.Normal( - array_ops.zeros(shape, dtype), array_ops.ones(shape, dtype)) - - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale, prior=prior_init)): - w = variable_scope.get_variable("weights", (10, 20)) - - x = random_ops.random_uniform((8, 10)) - y = math_ops.matmul(x, w) - - prior_map = vi._find_variational_and_priors(y, None) - self.assertTrue(isinstance(prior_map[w], dist.Normal)) - elbo = vi.elbo(y, keep_batch_dim=False) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(elbo) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py deleted file mode 100644 index fff6b74b2efed27abd7b25cbe0e8e8b3904767e1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for variational inference.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import distributions as distributions_lib -from tensorflow.contrib import layers -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.contrib.bayesflow.python.ops import variational_inference_impl -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables -from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.ops.distributions import normal -from tensorflow.python.platform import test - -st = stochastic_tensor -vi = variational_inference_impl -distributions = distributions_lib - - -class NormalNoEntropy(distributions.Normal): - - def entropy(self): - raise NotImplementedError("entropy not implemented") - - -# For mini-VAE -def inference_net(x, latent_size): - return layers.linear(x, latent_size) - - -def generative_net(z, data_size): - return layers.linear(z, data_size) - - -def mini_vae(): - x = [[-6., 3., 6.], [-8., 4., 8.]] - prior = distributions.Normal(loc=0., scale=1.) - variational = st.StochasticTensor( - distributions.Normal( - loc=inference_net(x, 1), scale=1.)) - vi.register_prior(variational, prior) - px = distributions.Normal(loc=generative_net(variational, 3), scale=1.) - log_likelihood = math_ops.reduce_sum(px.log_prob(x), 1) - log_likelihood = array_ops.expand_dims(log_likelihood, -1) - return x, prior, variational, px, log_likelihood - - -class VariationalInferenceTest(test.TestCase): - - def testDefaultVariationalAndPrior(self): - _, prior, variational, _, log_likelihood = mini_vae() - elbo = vi.elbo(log_likelihood) - expected_elbo = log_likelihood - kullback_leibler.kl_divergence( - variational.distribution, prior) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(*sess.run([expected_elbo, elbo])) - - def testExplicitVariationalAndPrior(self): - with self.test_session() as sess: - _, _, variational, _, log_likelihood = mini_vae() - prior = normal.Normal(loc=3., scale=2.) - elbo = vi.elbo( - log_likelihood, variational_with_prior={variational: prior}) - expected_elbo = log_likelihood - kullback_leibler.kl_divergence( - variational.distribution, prior) - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(*sess.run([expected_elbo, elbo])) - - def testExplicitForms(self): - _, prior, variational, _, log_likelihood = mini_vae() - - elbos = [] - forms = vi.ELBOForms - for form in [ - forms.default, forms.analytic_kl, forms.sample, forms.analytic_entropy - ]: - elbo = vi.elbo( - log_likelihood=log_likelihood, - variational_with_prior={variational: prior}, - form=form) - elbos.append(elbo) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - log_likelihood_shape = array_ops.shape(log_likelihood).eval() - for elbo in elbos: - elbo.eval() - elbo_shape = array_ops.shape(elbo).eval() - self.assertAllEqual(log_likelihood_shape, elbo_shape) - self.assertEqual(elbo.dtype, log_likelihood.dtype) - - def testDefaultsSampleKLWithoutAnalyticKLOrEntropy(self): - x = constant_op.constant([[-6., 3., 6.]]) - - prior = distributions.Bernoulli(0.5) - variational = st.StochasticTensor( - NormalNoEntropy( - loc=inference_net(x, 1), scale=1.)) - vi.register_prior(variational, prior) - px = distributions.Normal(loc=generative_net(variational, 3), scale=1.) - log_likelihood = math_ops.reduce_sum(px.log_prob(x), 1) - - # No analytic KL available between prior and variational distributions. - with self.assertRaisesRegexp(NotImplementedError, "No KL"): - distributions.kl_divergence(variational.distribution, prior) - - elbo = vi.elbo( - variational_with_prior={variational: prior}, - log_likelihood=log_likelihood) - expected_elbo = log_likelihood + prior.log_prob( - variational) - variational.distribution.log_prob(variational) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(*sess.run([expected_elbo, elbo])) - - def testElboWithLogJoint(self): - with self.test_session() as sess: - _, prior, variational, _, log_likelihood = mini_vae() - log_joint = log_likelihood + prior.log_prob(variational) - elbo = vi.elbo_with_log_joint(log_joint) - sess.run(variables.global_variables_initializer()) - elbo.eval() - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py b/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py deleted file mode 100644 index 4a7679fb436b91c9ae70daf85552099e5b710cbc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Support for Entropy Ops. See ${python/contrib.bayesflow.entropy}. - -@@elbo_ratio -@@entropy_shannon -@@renyi_ratio -@@renyi_alpha -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math - -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo -from tensorflow.contrib.bayesflow.python.ops import variational_inference -from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples as get_samples -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import tf_logging as logging - - -# Make utility functions from monte_carlo available. -# pylint: disable=protected-access -_get_samples = get_samples -_logspace_mean = monte_carlo._logspace_mean -_sample_mean = monte_carlo._sample_mean - -# pylint: enable=protected-access - -__all__ = [ - 'elbo_ratio', - 'entropy_shannon', - 'renyi_ratio', - 'renyi_alpha', -] - -ELBOForms = variational_inference.ELBOForms # pylint: disable=invalid-name - - -def elbo_ratio(log_p, - q, - z=None, - n=None, - seed=None, - form=None, - name='elbo_ratio'): - r"""Estimate of the ratio appearing in the `ELBO` and `KL` divergence. - - With `p(z) := exp{log_p(z)}`, this `Op` returns an approximation of - - ``` - E_q[ Log[p(Z) / q(Z)] ] - ``` - - The term `E_q[ Log[p(Z)] ]` is always computed as a sample mean. - The term `E_q[ Log[q(z)] ]` can be computed with samples, or an exact formula - if `q.entropy()` is defined. This is controlled with the kwarg `form`. - - This log-ratio appears in different contexts: - - #### `KL[q || p]` - - If `log_p(z) = Log[p(z)]` for distribution `p`, this `Op` approximates - the negative Kullback-Leibler divergence. - - ``` - elbo_ratio(log_p, q, n=100) = -1 * KL[q || p], - KL[q || p] = E[ Log[q(Z)] - Log[p(Z)] ] - ``` - - Note that if `p` is a `Distribution`, then - `distributions.kl_divergence(q, p)` may be defined and available as an - exact result. - - #### ELBO - - If `log_p(z) = Log[p(z, x)]` is the log joint of a distribution `p`, this is - the Evidence Lower BOund (ELBO): - - ``` - ELBO ~= E[ Log[p(Z, x)] - Log[q(Z)] ] - = Log[p(x)] - KL[q || p] - <= Log[p(x)] - ``` - - User supplies either `Tensor` of samples `z`, or number of samples to draw `n` - - Args: - log_p: Callable mapping samples from `q` to `Tensors` with - shape broadcastable to `q.batch_shape`. - For example, `log_p` works "just like" `q.log_prob`. - q: `tf.contrib.distributions.Distribution`. - z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`. - n: Integer `Tensor`. Number of samples to generate if `z` is not provided. - seed: Python integer to seed the random number generator. - form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`) - or `ELBOForms.sample` (sample estimate of entropy), or `ELBOForms.default` - (attempt analytic entropy, fallback on sample). - Default value is `ELBOForms.default`. - name: A name to give this `Op`. - - Returns: - Scalar `Tensor` holding sample mean KL divergence. `shape` is the batch - shape of `q`, and `dtype` is the same as `q`. - - Raises: - ValueError: If `form` is not handled by this function. - """ - form = ELBOForms.default if form is None else form - - with ops.name_scope(name, values=[n, z]): - z = _get_samples(q, z, n, seed) - - entropy = entropy_shannon(q, z=z, form=form) - - # If log_p(z) = Log[p(z)], cross entropy = -E_q[log(p(Z))] - negative_cross_entropy = _sample_mean(log_p(z)) - - return entropy + negative_cross_entropy - - -def entropy_shannon(p, - z=None, - n=None, - seed=None, - form=None, - name='entropy_shannon'): - r"""Monte Carlo or deterministic computation of Shannon's entropy. - - Depending on the kwarg `form`, this `Op` returns either the analytic entropy - of the distribution `p`, or the sampled entropy: - - ``` - -n^{-1} sum_{i=1}^n p.log_prob(z_i), where z_i ~ p, - \approx - E_p[ Log[p(Z)] ] - = Entropy[p] - ``` - - User supplies either `Tensor` of samples `z`, or number of samples to draw `n` - - Args: - p: `tf.contrib.distributions.Distribution` - z: `Tensor` of samples from `p`, produced by `p.sample(n)` for some `n`. - n: Integer `Tensor`. Number of samples to generate if `z` is not provided. - seed: Python integer to seed the random number generator. - form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`) - or `ELBOForms.sample` (sample estimate of entropy), or `ELBOForms.default` - (attempt analytic entropy, fallback on sample). - Default value is `ELBOForms.default`. - name: A name to give this `Op`. - - Returns: - A `Tensor` with same `dtype` as `p`, and shape equal to `p.batch_shape`. - - Raises: - ValueError: If `form` not handled by this function. - ValueError: If `form` is `ELBOForms.analytic_entropy` and `n` was provided. - """ - form = ELBOForms.default if form is None else form - - if n is not None and form == ELBOForms.analytic_entropy: - raise ValueError('If form == ELBOForms.analytic_entropy, n must be None.') - - with ops.name_scope(name, values=[n, z]): - # Entropy: -E_p[log(p(Z))]. - entropy = None - - # Try analytic path - if form in [ELBOForms.default, ELBOForms.analytic_entropy]: - try: - entropy = p.entropy() - logging.info('Using analytic entropy(p:%s)', p) - except NotImplementedError as e: - if form == ELBOForms.analytic_entropy: - raise e - elif form != ELBOForms.sample: - raise ValueError('ELBOForm not handled by this function: %s' % form) - - # Sample path - if entropy is None: - logging.info('Using sampled entropy(p:%s)', p) - if z is None: - z = p.sample(n, seed=seed) - entropy = -monte_carlo.expectation(p.log_prob, z) - - return entropy - - -def renyi_ratio(log_p, q, alpha, z=None, n=None, seed=None, name='renyi_ratio'): - r"""Monte Carlo estimate of the ratio appearing in Renyi divergence. - - This can be used to compute the Renyi (alpha) divergence, or a log evidence - approximation based on Renyi divergence. - - #### Definition - - With `z_i` iid samples from `q`, and `exp{log_p(z)} = p(z)`, this `Op` returns - the (biased for finite `n`) estimate: - - ``` - (1 - alpha)^{-1} Log[ n^{-1} sum_{i=1}^n ( p(z_i) / q(z_i) )^{1 - alpha}, - \approx (1 - alpha)^{-1} Log[ E_q[ (p(Z) / q(Z))^{1 - alpha} ] ] - ``` - - This ratio appears in different contexts: - - #### Renyi divergence - - If `log_p(z) = Log[p(z)]` is the log prob of a distribution, and - `alpha > 0`, `alpha != 1`, this `Op` approximates `-1` times Renyi divergence: - - ``` - # Choose reasonably high n to limit bias, see below. - renyi_ratio(log_p, q, alpha, n=100) - \approx -1 * D_alpha[q || p], where - D_alpha[q || p] := (1 - alpha)^{-1} Log E_q[(p(Z) / q(Z))^{1 - alpha}] - ``` - - The Renyi (or "alpha") divergence is non-negative and equal to zero iff - `q = p`. Various limits of `alpha` lead to different special case results: - - ``` - alpha D_alpha[q || p] - ----- --------------- - --> 0 Log[ int_{q > 0} p(z) dz ] - = 0.5, -2 Log[1 - Hel^2[q || p]], (\propto squared Hellinger distance) - --> 1 KL[q || p] - = 2 Log[ 1 + chi^2[q || p] ], (\propto squared Chi-2 divergence) - --> infty Log[ max_z{q(z) / p(z)} ], (min description length principle). - ``` - - See "Renyi Divergence Variational Inference", by Li and Turner. - - #### Log evidence approximation - - If `log_p(z) = Log[p(z, x)]` is the log of the joint distribution `p`, this is - an alternative to the ELBO common in variational inference. - - ``` - L_alpha(q, p) = Log[p(x)] - D_alpha[q || p] - ``` - - If `q` and `p` have the same support, and `0 < a <= b < 1`, one can show - `ELBO <= D_b <= D_a <= Log[p(x)]`. Thus, this `Op` allows a smooth - interpolation between the ELBO and the true evidence. - - #### Stability notes - - Note that when `1 - alpha` is not small, the ratio `(p(z) / q(z))^{1 - alpha}` - is subject to underflow/overflow issues. For that reason, it is evaluated in - log-space after centering. Nonetheless, infinite/NaN results may occur. For - that reason, one may wish to shrink `alpha` gradually. See the `Op` - `renyi_alpha`. Using `float64` will also help. - - - #### Bias for finite sample size - - Due to nonlinearity of the logarithm, for random variables `{X_1,...,X_n}`, - `E[ Log[sum_{i=1}^n X_i] ] != Log[ E[sum_{i=1}^n X_i] ]`. As a result, this - estimate is biased for finite `n`. For `alpha < 1`, it is non-decreasing - with `n` (in expectation). For example, if `n = 1`, this estimator yields the - same result as `elbo_ratio`, and as `n` increases the expected value - of the estimator increases. - - #### Call signature - - User supplies either `Tensor` of samples `z`, or number of samples to draw `n` - - Args: - log_p: Callable mapping samples from `q` to `Tensors` with - shape broadcastable to `q.batch_shape`. - For example, `log_p` works "just like" `q.log_prob`. - q: `tf.contrib.distributions.Distribution`. - `float64` `dtype` recommended. - `log_p` and `q` should be supported on the same set. - alpha: `Tensor` with shape `q.batch_shape` and values not equal to 1. - z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`. - n: Integer `Tensor`. The number of samples to use if `z` is not provided. - Note that this can be highly biased for small `n`, see docstring. - seed: Python integer to seed the random number generator. - name: A name to give this `Op`. - - Returns: - renyi_result: The scaled log of sample mean. `Tensor` with `shape` equal - to batch shape of `q`, and `dtype` = `q.dtype`. - """ - with ops.name_scope(name, values=[alpha, n, z]): - z = _get_samples(q, z, n, seed) - - # Evaluate sample mean in logspace. Note that _logspace_mean will compute - # (among other things) the mean of q.log_prob(z), which could also be - # obtained with q.entropy(). However, DON'T use analytic entropy, because - # that increases variance, and could result in NaN/Inf values of a sensitive - # term. - - # log_values - # = (1 - alpha) * ( Log p - Log q ) - log_values = (1. - alpha) * (log_p(z) - q.log_prob(z)) - - # log_mean_values - # = Log[ E[ values ] ] - # = Log[ E[ (p / q)^{1-alpha} ] ] - log_mean_values = _logspace_mean(log_values) - - return log_mean_values / (1. - alpha) - - -def renyi_alpha(step, - decay_time, - alpha_min, - alpha_max=0.99999, - name='renyi_alpha'): - r"""Exponentially decaying `Tensor` appropriate for Renyi ratios. - - When minimizing the Renyi divergence for `0 <= alpha < 1` (or maximizing the - Renyi equivalent of elbo) in high dimensions, it is not uncommon to experience - `NaN` and `inf` values when `alpha` is far from `1`. - - For that reason, it is often desirable to start the optimization with `alpha` - very close to 1, and reduce it to a final `alpha_min` according to some - schedule. The user may even want to optimize using `elbo_ratio` for - some fixed time before switching to Renyi based methods. - - This `Op` returns an `alpha` decaying exponentially with step: - - ``` - s(step) = (exp{step / decay_time} - 1) / (e - 1) - t(s) = max(0, min(s, 1)), (smooth growth from 0 to 1) - alpha(t) = (1 - t) alpha_min + t alpha_max - ``` - - Args: - step: Non-negative scalar `Tensor`. Typically the global step or an - offset version thereof. - decay_time: Positive scalar `Tensor`. - alpha_min: `float` or `double` `Tensor`. - The minimal, final value of `alpha`, achieved when `step >= decay_time` - alpha_max: `Tensor` of same `dtype` as `alpha_min`. - The maximal, beginning value of `alpha`, achieved when `step == 0` - name: A name to give this `Op`. - - Returns: - alpha: A `Tensor` of same `dtype` as `alpha_min`. - """ - with ops.name_scope(name, values=[step, decay_time, alpha_min, alpha_max]): - alpha_min = ops.convert_to_tensor(alpha_min, name='alpha_min') - dtype = alpha_min.dtype - - alpha_max = ops.convert_to_tensor(alpha_max, dtype=dtype, name='alpha_max') - decay_time = math_ops.cast(decay_time, dtype) - step = math_ops.cast(step, dtype) - - check_scalars = [ - check_ops.assert_rank(step, 0, message='step must be scalar'), - check_ops.assert_rank( - decay_time, 0, message='decay_time must be scalar'), - check_ops.assert_rank(alpha_min, 0, message='alpha_min must be scalar'), - check_ops.assert_rank(alpha_max, 0, message='alpha_max must be scalar'), - ] - check_sign = [ - check_ops.assert_non_negative( - step, message='step must be non-negative'), - check_ops.assert_positive( - decay_time, message='decay_time must be positive'), - ] - - with ops.control_dependencies(check_scalars + check_sign): - theta = (math_ops.exp(step / decay_time) - 1.) / (math.e - 1.) - theta = math_ops.minimum(math_ops.maximum(theta, 0.), 1.) - return alpha_max * (1. - theta) + alpha_min * theta diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py similarity index 82% rename from tensorflow/contrib/bayesflow/python/ops/entropy.py rename to tensorflow/contrib/bayesflow/python/ops/halton_sequence.py index a22e1c1d4e098439760267fca1374f986e45be8f..49d747d538f5a4aa3134d28ba00a651cb509fa41 100644 --- a/tensorflow/contrib/bayesflow/python/ops/entropy.py +++ b/tensorflow/contrib/bayesflow/python/ops/halton_sequence.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support for Entropy Ops. See ${python/contrib.bayesflow.entropy}.""" +"""Support for low discrepancy Halton sequences. + +""" from __future__ import absolute_import from __future__ import division @@ -20,12 +22,12 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.entropy_impl import * +from tensorflow.contrib.bayesflow.python.ops.halton_sequence_impl import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'ELBOForms', 'elbo_ratio', 'entropy_shannon', 'renyi_ratio', 'renyi_alpha' + 'sample', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..8cabf18903b5f15002470acdfb8fdd3ec31a7413 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/halton_sequence_impl.py @@ -0,0 +1,264 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Quasi Monte Carlo support: Halton sequence. + +@@sample +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +__all__ = [ + 'sample', +] + + +# The maximum dimension we support. This is limited by the number of primes +# in the _PRIMES array. +_MAX_DIMENSION = 1000 + + +def sample(dim, num_samples=None, sample_indices=None, dtype=None, name=None): + r"""Returns a sample from the `m` dimensional Halton sequence. + + Warning: The sequence elements take values only between 0 and 1. Care must be + taken to appropriately transform the domain of a function if it differs from + the unit cube before evaluating integrals using Halton samples. It is also + important to remember that quasi-random numbers are not a replacement for + pseudo-random numbers in every context. Quasi random numbers are completely + deterministic and typically have significant negative autocorrelation (unless + randomized). + + Computes the members of the low discrepancy Halton sequence in dimension + `dim`. The d-dimensional sequence takes values in the unit hypercube in d + dimensions. Currently, only dimensions up to 1000 are supported. The prime + base for the `k`-th axes is the k-th prime starting from 2. For example, + if dim = 3, then the bases will be [2, 3, 5] respectively and the first + element of the sequence will be: [0.5, 0.333, 0.2]. For a more complete + description of the Halton sequences see: + https://en.wikipedia.org/wiki/Halton_sequence. For low discrepancy sequences + and their applications see: + https://en.wikipedia.org/wiki/Low-discrepancy_sequence. + + The user must supply either `num_samples` or `sample_indices` but not both. + The former is the number of samples to produce starting from the first + element. If `sample_indices` is given instead, the specified elements of + the sequence are generated. For example, sample_indices=tf.range(10) is + equivalent to specifying n=10. + + Example Use: + + ```python + bf = tf.contrib.bayesflow + + # Produce the first 1000 members of the Halton sequence in 3 dimensions. + num_samples = 1000 + dim = 3 + sample = bf.halton_sequence.sample(dim, num_samples=num_samples) + + # Evaluate the integral of x_1 * x_2^2 * x_3^3 over the three dimensional + # hypercube. + powers = tf.range(1.0, limit=dim + 1) + integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1)) + true_value = 1.0 / tf.reduce_prod(powers + 1.0) + with tf.Session() as session: + values = session.run((integral, true_value)) + + # Produces a relative absolute error of 1.7%. + print ("Estimated: %f, True Value: %f" % values) + + # Now skip the first 1000 samples and recompute the integral with the next + # thousand samples. The sample_indices argument can be used to do this. + + + sample_indices = tf.range(start=1000, limit=1000 + num_samples, + dtype=tf.int32) + sample_leaped = halton.sample(dim, sample_indices=sample_indices) + + integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers, + axis=-1)) + with tf.Session() as session: + values = session.run((integral_leaped, true_value)) + # Now produces a relative absolute error of 0.05%. + print ("Leaped Estimated: %f, True Value: %f" % values) + ``` + + Args: + dim: Positive Python `int` representing each sample's `event_size.` Must + not be greater than 1000. + num_samples: (Optional) positive Python `int`. The number of samples to + generate. Either this parameter or sample_indices must be specified but + not both. If this parameter is None, then the behaviour is determined by + the `sample_indices`. + sample_indices: (Optional) `Tensor` of dtype int32 and rank 1. The elements + of the sequence to compute specified by their position in the sequence. + The entries index into the Halton sequence starting with 0 and hence, + must be whole numbers. For example, sample_indices=[0, 5, 6] will produce + the first, sixth and seventh elements of the sequence. If this parameter + is None, then the `num_samples` parameter must be specified which gives + the number of desired samples starting from the first sample. + dtype: (Optional) The dtype of the sample. One of `float32` or `float64`. + Default is `float32`. + name: (Optional) Python `str` describing ops managed by this function. If + not supplied the name of this function is used. + + Returns: + halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype + and `shape` `[num_samples, dim]` if `num_samples` was specified or shape + `[s, dim]` where s is the size of `sample_indices` if `sample_indices` + were specified. + + Raises: + ValueError: if both `sample_indices` and `num_samples` were specified or + if dimension `dim` is less than 1 or greater than 1000. + """ + if dim < 1 or dim > _MAX_DIMENSION: + raise ValueError( + 'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION, + dim)) + if (num_samples is None) == (sample_indices is None): + raise ValueError('Either `num_samples` or `sample_indices` must be' + ' specified but not both.') + + dtype = dtype or dtypes.float32 + if not dtype.is_floating: + raise ValueError('dtype must be of `float`-type') + + with ops.name_scope(name, 'sample', values=[sample_indices]): + # Here and in the following, the shape layout is as follows: + # [sample dimension, event dimension, coefficient dimension]. + # The coefficient dimension is an intermediate axes which will hold the + # weights of the starting integer when expressed in the (prime) base for + # an event dimension. + indices = _get_indices(num_samples, sample_indices, dtype) + radixes = array_ops.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1]) + + max_sizes_by_axes = _base_expansion_size(math_ops.reduce_max(indices), + radixes) + + max_size = math_ops.reduce_max(max_sizes_by_axes) + + # The powers of the radixes that we will need. Note that there is a bit + # of an excess here. Suppose we need the place value coefficients of 7 + # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits + # for base 3. However, we can only create rectangular tensors so we + # store both expansions in a [2, 3] tensor. This leads to the problem that + # we might end up attempting to raise large numbers to large powers. For + # example, base 2 expansion of 1024 has 10 digits. If we were in 10 + # dimensions, then the 10th prime (29) we will end up computing 29^10 even + # though we don't need it. We avoid this by setting the exponents for each + # axes to 0 beyond the maximum value needed for that dimension. + exponents_by_axes = array_ops.tile([math_ops.range(max_size)], [dim, 1]) + weight_mask = exponents_by_axes > max_sizes_by_axes + capped_exponents = array_ops.where( + weight_mask, array_ops.zeros_like(exponents_by_axes), exponents_by_axes) + weights = radixes ** capped_exponents + coeffs = math_ops.floor_div(indices, weights) + coeffs *= 1 - math_ops.cast(weight_mask, dtype) + coeffs = (coeffs % radixes) / radixes + return math_ops.reduce_sum(coeffs / weights, axis=-1) + + +def _get_indices(n, sample_indices, dtype, name=None): + """Generates starting points for the Halton sequence procedure. + + The k'th element of the sequence is generated starting from a positive integer + which must be distinct for each `k`. It is conventional to choose the starting + point as `k` itself (or `k+1` if k is zero based). This function generates + the starting integers for the required elements and reshapes the result for + later use. + + Args: + n: Positive `int`. The number of samples to generate. If this + parameter is supplied, then `sample_indices` should be None. + sample_indices: `Tensor` of dtype int32 and rank 1. The entries + index into the Halton sequence starting with 0 and hence, must be whole + numbers. For example, sample_indices=[0, 5, 6] will produce the first, + sixth and seventh elements of the sequence. If this parameter is not None + then `n` must be None. + dtype: The dtype of the sample. One of `float32` or `float64`. + Default is `float32`. + name: Python `str` name which describes ops created by this function. + + Returns: + indices: `Tensor` of dtype `dtype` and shape = `[n, 1, 1]`. + """ + with ops.name_scope(name, 'get_indices', [n, sample_indices]): + if sample_indices is None: + sample_indices = math_ops.range(n, dtype=dtype) + else: + sample_indices = math_ops.cast(sample_indices, dtype) + + # Shift the indices so they are 1 based. + indices = sample_indices + 1 + + # Reshape to make space for the event dimension and the place value + # coefficients. + return array_ops.reshape(indices, [-1, 1, 1]) + + +def _base_expansion_size(num, bases): + """Computes the number of terms in the place value expansion. + + Let num = a0 + a1 b + a2 b^2 + ... ak b^k be the place value expansion of + `num` in base b (ak <> 0). This function computes and returns `k` for each + base `b` specified in `bases`. + + This can be inferred from the base `b` logarithm of `num` as follows: + $$k = Floor(log_b (num)) + 1 = Floor( log(num) / log(b)) + 1$$ + + Args: + num: Scalar `Tensor` of dtype either `float32` or `float64`. The number to + compute the base expansion size of. + bases: `Tensor` of the same dtype as num. The bases to compute the size + against. + + Returns: + Tensor of same dtype and shape as `bases` containing the size of num when + written in that base. + """ + return math_ops.floor(math_ops.log(num) / math_ops.log(bases)) + 1 + + +def _primes_less_than(n): + # Based on + # https://stackoverflow.com/questions/2068372/fastest-way-to-list-all-primes-below-n-in-python/3035188#3035188 + """Returns sorted array of primes such that `2 <= prime < n`.""" + small_primes = np.array((2, 3, 5)) + if n <= 6: + return small_primes[small_primes < n] + sieve = np.ones(n // 3 + (n % 6 == 2), dtype=np.bool) + sieve[0] = False + m = int(n ** 0.5) // 3 + 1 + for i in range(m): + if not sieve[i]: + continue + k = 3 * i + 1 | 1 + sieve[k ** 2 // 3::2 * k] = False + sieve[(k ** 2 + 4 * k - 2 * k * (i & 1)) // 3::2 * k] = False + return np.r_[2, 3, 3 * np.nonzero(sieve)[0] + 1 | 1] + +_PRIMES = _primes_less_than(7919+1) + +assert len(_PRIMES) == _MAX_DIMENSION diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index 333dce929530adceb30dcb63653a5bd009c059e0..5685a942e98800a39ec718adc67bcfd43aeafd52 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -27,6 +27,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -174,9 +175,11 @@ def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, potential_and_grad = _make_potential_and_grad(target_log_prob_fn) potential, grad = potential_and_grad(initial_x) - return functional_ops.scan(body, array_ops.zeros(n_iterations), - (initial_x, array_ops.zeros(non_event_shape), - -potential, -grad))[:2] + return functional_ops.scan( + body, array_ops.zeros(n_iterations, dtype=initial_x.dtype), + (initial_x, + array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + -potential, -grad))[:2] def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, @@ -298,8 +301,9 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, return updated_x, acceptance_probs, w x, acceptance_probs, w = functional_ops.scan( - _body, beta_series, (initial_x, array_ops.zeros(non_event_shape), - array_ops.zeros(non_event_shape))) + _body, beta_series, + (initial_x, array_ops.zeros(non_event_shape, dtype=initial_x.dtype), + array_ops.zeros(non_event_shape, dtype=initial_x.dtype))) return w[-1], x[-1], acceptance_probs[-1] @@ -446,9 +450,10 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), """ with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): potential_and_grad = _make_potential_and_grad(target_log_prob_fn) + x = ops.convert_to_tensor(x, name='x') x_shape = array_ops.shape(x) - m = random_ops.random_normal(x_shape) + m = random_ops.random_normal(x_shape, dtype=x.dtype) kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) @@ -468,26 +473,33 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims) - # TODO(mhoffman): It seems like there may be an opportunity for nans here. - # I'm delaying addressing this because we're going to refactor this part - # to use the more general Metropolis abstraction anyway. - acceptance_probs = math_ops.exp(math_ops.minimum(0., log_potential_0 - - log_potential_1 + - kinetic_0 - kinetic_1)) - accepted = math_ops.cast( - random_ops.random_uniform(array_ops.shape(acceptance_probs)) < - acceptance_probs, np.float32) - new_log_prob = (-log_potential_0 * (1. - accepted) - - log_potential_1 * accepted) + energy_change = log_potential_1 - log_potential_0 + kinetic_1 - kinetic_0 + # Treat NaN as infinite energy (and therefore guaranteed rejection). + energy_change = array_ops.where( + math_ops.is_nan(energy_change), + array_ops.fill(array_ops.shape(energy_change), + energy_change.dtype.as_numpy_dtype(np.inf)), + energy_change) + acceptance_probs = math_ops.exp(math_ops.minimum(-energy_change, 0.)) + accepted = ( + random_ops.random_uniform( + array_ops.shape(acceptance_probs), dtype=x.dtype) + < acceptance_probs) + new_log_prob = -array_ops.where(accepted, log_potential_1, log_potential_0) # TODO(b/65738010): This should work, but it doesn't for now. # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, keep_dims=True)) accepted = array_ops.reshape(accepted, reduced_shape) - new_x = x * (1. - accepted) + new_x * accepted - new_grad = -grad_0 * (1. - accepted) - grad_1 * accepted - + accepted = math_ops.logical_or( + accepted, math_ops.cast(array_ops.zeros_like(x), dtypes.bool)) + new_x = array_ops.where(accepted, new_x, x) + new_grad = -array_ops.where(accepted, grad_1, grad_0) + + # TODO(langmore) Gradients of acceptance_probs and new_log_prob with respect + # to initial_x will propagate NaNs (see testNanFromGradsDontPropagate). This + # should be fixed. return new_x, acceptance_probs, new_log_prob, new_grad @@ -525,6 +537,7 @@ def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum, Has shape matching `initial_position`. Example: Simple quadratic potential. + ```python def potential_and_grad(position): return tf.reduce_sum(0.5 * tf.square(position)), position @@ -600,6 +613,7 @@ def leapfrog_step(step_size, position, momentum, potential_and_grad, grad, Has shape matching `position`. Example: Simple quadratic potential. + ```python def potential_and_grad(position): # Simple quadratic potential diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_inference.py b/tensorflow/contrib/bayesflow/python/ops/layers.py similarity index 74% rename from tensorflow/contrib/bayesflow/python/ops/variational_inference.py rename to tensorflow/contrib/bayesflow/python/ops/layers.py index 6316361da2accf39dfe2e77902eec06813ca7036..dcead38af826a12e776160bdb251ba021e6b953c 100644 --- a/tensorflow/contrib/bayesflow/python/ops/variational_inference.py +++ b/tensorflow/contrib/bayesflow/python/ops/layers.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Variational inference. +"""Probabilistic neural layers. -See the ${@python/contrib.bayesflow.variational_inference} guide. +See ${python/contrib.bayesflow.layers}. """ from __future__ import absolute_import @@ -23,12 +23,15 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.variational_inference_impl import * +from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - "elbo", "elbo_with_log_joint", "ELBOForms", "register_prior" + 'DenseVariational', + 'dense_variational', + 'default_loc_scale_fn', + 'default_mean_field_normal_fn', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..b05ce0ffc1dd55ffb029b339a846a9aa5c877620 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py @@ -0,0 +1,797 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dense Bayesian layer using KL-divergence based variational inference. + +@@DenseVariational +@@dense_variational + +@@default_loc_scale_fn +@@default_mean_field_normal_fn +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib + + +__all__ = [ + "DenseVariational", + "dense_variational", + "default_loc_scale_fn", + "default_mean_field_normal_fn", +] + + +def default_loc_scale_fn( + is_singular=False, + loc_initializer=init_ops.random_normal_initializer(stddev=0.1), + untransformed_scale_initializer=init_ops.random_normal_initializer( + mean=-3., stddev=0.1), + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. + + This function produces a closure which produces `loc`, `scale` using + `tf.get_variable`. The closure accepts the following arguments: + + dtype: Type of parameter's event. + shape: Python `list`-like representing the parameter's event shape. + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` indicating if `scale is None`. Default: `False`. + loc_initializer: Initializer function for the `loc` parameters. + The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. Default value: `tf.random_normal_initializer(mean=-3., + stddev=0.1)`. This implies the softplus transformed result has mean + approximately `0.05` and std. deviation approximately `0.005`. + loc_regularizer: Regularizer function for the `loc` parameters. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. The default (`None`) is to use the `tf.get_variable` default. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. The default + (`None`) is to use the `tf.get_variable` default. + + Returns: + default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` + parameters from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates `loc`, `scale` parameters.""" + loc = add_variable_fn( + name=name + "_loc", + shape=shape, + initializer=loc_initializer, + regularizer=loc_regularizer, + constraint=loc_constraint, + dtype=dtype, + trainable=trainable) + if is_singular: + return loc, None + untransformed_scale = add_variable_fn( + name=name + "_untransformed_scale", + shape=shape, + initializer=untransformed_scale_initializer, + regularizer=untransformed_scale_regularizer, + constraint=untransformed_scale_constraint, + dtype=dtype, + trainable=trainable) + scale = (np.finfo(dtype.as_numpy_dtype).eps + + nn_ops.softplus(untransformed_scale)) + return loc, scale + return _fn + + +def default_mean_field_normal_fn( + is_singular=False, + loc_initializer=None, + untransformed_scale_initializer=None, + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Creates a function to build Normal distributions with trainable params. + + This function produces a closure which produces `tf.distributions.Normal` + parameterized by a loc` and `scale` each created using `tf.get_variable`. The + produced closure accepts the following arguments: + + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` if `True`, forces the special case limit of + `scale->0`, i.e., a `Deterministic` distribution. + loc_initializer: Initializer function for the `loc` parameters. + If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + loc_regularizer: Regularizer function for the `loc` parameters. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. + + Returns: + make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` + using from args: `dtype, shape, name, trainable, add_variable_fn`. + """ + loc_scale_fn_ = default_loc_scale_fn( + is_singular, + loc_initializer, + untransformed_scale_initializer, + loc_regularizer, + untransformed_scale_regularizer, + loc_constraint, + untransformed_scale_constraint) + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates a batch of `Deterministic` or `Normal` distributions.""" + loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) + if scale is None: + return deterministic_lib.Deterministic(loc=loc) + return normal_lib.Normal(loc=loc, scale=scale) + return _fn + + +class DenseVariational(layers_lib.Layer): + """Densely-connected variational class. + + This layer implements the Bayesian variational inference analogue to: + `outputs = activation(matmul(inputs, kernel) + bias)` + by assuming the `kernel` and/or the `bias` are random variables. + + The layer implements a stochastic dense calculation by making a Monte Carlo + approximation of a [variational Bayesian method based on KL divergence]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + + ```none + -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw + = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw + <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's + = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] + ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } + + KL[q(W|x), p(W)] + ``` + + where `W` denotes the (independent) `kernel` and `bias` random variables, `w` + is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, + and `~=` denotes an approximation which becomes exact as `m->inf`. The above + bound is sometimes referred to as the negative Evidence Lower BOund or + negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this + layer is appropriate to use when the final loss is a negative log-likelihood. + + The Monte-Carlo sum portion is used for the feed-forward calculation of the + DNN. The KL divergence portion can be added to the final loss via: + `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + random variables (which together comprise `W`). + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + When `True`, `kernel_posterior_fn` must create an instance of + `tf.distributions.Normal`. + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + kernel: `VariationalKernelParamater` instance containing all `kernel` + related properties and `callable`s. + bias: `VariationalParameter` instance containing all `kernel` + related properties and `callable`s. + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self._units = units + self._activation = activation + self._input_spec = layers_lib.InputSpec(min_ndim=2) + self._kernel_use_local_reparameterization = ( + kernel_use_local_reparameterization) + self._kernel = VariationalKernelParameter( + kernel_posterior_fn, + kernel_posterior_tensor_fn, + kernel_prior_fn, + kernel_divergence_fn) + self._bias = VariationalParameter( + bias_posterior_fn, + bias_posterior_tensor_fn, + bias_prior_fn, + bias_divergence_fn) + + @property + def units(self): + return self._units + + @property + def activation(self): + return self._activation + + @property + def input_spec(self): + return self._input_spec + + @input_spec.setter + def input_spec(self, value): + self._input_spec = value + + @property + def kernel_use_local_reparameterization(self): + return self._kernel_use_local_reparameterization + + @property + def kernel(self): + return self._kernel + + @property + def bias(self): + return self._bias + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + in_size = input_shape.with_rank_at_least(2)[-1].value + if in_size is None: + raise ValueError("The last dimension of the inputs to `Dense` " + "should be defined. Found `None`.") + self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. + self.kernel.posterior = self.kernel.posterior_fn( + dtype, [in_size, self.units], "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel.prior_fn is None: + self.kernel_prior = None + else: + self.kernel.prior = self.kernel.prior_fn( + dtype, [in_size, self.units], "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias.posterior_fn is None: + self.bias.posterior = None + else: + self.bias.posterior = self.bias.posterior_fn( + dtype, [self.units], "bias_posterior", + self.trainable, self.add_variable) + + if self.bias.prior_fn is None: + self.bias.prior = None + else: + self.bias.prior = self.bias.prior_fn( + dtype, [self.units], "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) # pylint: disable=not-callable + if not self._built_kernel_divergence: + self._apply_divergence(self.kernel, name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + self._apply_divergence(self.bias, name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_kernel(self, inputs): + if not self.kernel_use_local_reparameterization: + self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn( + self.kernel.posterior) + self.kernel.posterior_affine = None + self.kernel.posterior_affine_tensor = None + return self._matmul(inputs, self.kernel.posterior_tensor) + if not isinstance(self.kernel.posterior, normal_lib.Normal): + raise TypeError("`kernel_use_local_reparameterization=True` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Normal` (saw: \"{}\").".format( + type(self.kernel.posterior).__name__)) + self.kernel.posterior_affine = normal_lib.Normal( + loc=self._matmul(inputs, self.kernel.posterior.loc), + scale=standard_ops.sqrt(self._matmul( + standard_ops.square(inputs), + standard_ops.square(self.kernel.posterior.scale)))) + self.kernel.posterior_affine_tensor = ( + self.kernel.posterior_tensor_fn(self.kernel.posterior_affine)) + self.kernel.posterior_tensor = None + return self.kernel.posterior_affine_tensor + + def _apply_variational_bias(self, inputs): + if self.bias.posterior is None: + self.bias.posterior_tensor = None + return inputs + self.bias.posterior_tensor = self.bias.posterior_tensor_fn( + self.bias.posterior) + return nn.bias_add(inputs, self.bias.posterior_tensor) + + def _apply_divergence(self, param, name): + if (param.divergence_fn is None or + param.posterior is None or + param.prior is None): + param.divergence = None + return + param.divergence = standard_ops.identity( + param.divergence_fn( + param.posterior, param.prior, param.posterior_tensor), + name=name) + self.add_loss(param.divergence) + + def _matmul(self, inputs, kernel): + if inputs.shape.ndims <= 2: + return standard_ops.matmul(inputs, kernel) + # To handle broadcasting, we must use `tensordot`. + return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) + if input_shape[-1].value is None: + raise ValueError( + "The innermost dimension of input_shape must be defined, " + "but saw: {}".format(input_shape)) + return input_shape[:-1].concatenate(self.units) + + +def dense_variational( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_use_local_reparameterization=True, + kernel_posterior_fn=default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected variational layer. + + This layer implements the Bayesian variational inference analogue to: + `outputs = activation(matmul(inputs, kernel) + bias)` + by assuming the `kernel` and/or the `bias` are random variables. + + The layer implements a stochastic dense calculation by making a Monte Carlo + approximation of a [variational Bayesian method based on KL divergence]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., + + ```none + -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw + = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw + <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's + = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] + ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } + + KL[q(W|x), p(W)] + ``` + + where `W` denotes the (independent) `kernel` and `bias` random variables, `w` + is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, + and `~=` denotes an approximation which becomes exact as `m->inf`. The above + bound is sometimes referred to as the negative Evidence Lower BOund or + negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this + layer is appropriate to use when the final loss is a negative log-likelihood. + + The Monte-Carlo sum portion is used for the feed-forward calculation of the + DNN. The KL divergence portion can be added to the final loss via: + `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + random variables (which together comprise `W`). + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_use_local_reparameterization: Python `bool` indicating whether + `kernel` calculation should employ the Local Reparameterization Trick. + When `True`, `kernel_posterior_fn` must create an instance of + `tf.distributions.Normal`. + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + """ + layer = DenseVariational( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_use_local_reparameterization=( + kernel_use_local_reparameterization), + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class NotSet(object): + """Helper to track whether a `VariationalParameter` value has been set.""" + pass + + +class VariationalParameter(object): + """Struct-like container of variational parameter properties. + + A `VariationalParameter` is intitialized with Python `callable`s which set the + value of correspondingly named members. Corresponding values have "set once" + semantics, i.e., once set to any value they are immutable. + """ + + def __init__( + self, + posterior_fn, + posterior_tensor_fn, + prior_fn, + divergence_fn): + """Creates the `VariationalParameter` struct-like object. + + Args: + posterior_fn: Python `callable` which creates a + `tf.distribution.Distribution` like object representing the posterior + distribution. See `VariationalParameter.posterior_fn` for `callable`'s + required parameters. + posterior_tensor_fn: Python `callable` which computes a `Tensor` + which represents the `posterior`. + prior_fn: Python `callable` which creates a + `tf.distribution.Distribution` like object representing the prior + distribution. See `VariationalParameter.prior_fn` for `callable`'s + required parameters. + divergence_fn: Python `callable` which computes the KL divergence from + `posterior` to `prior`. See `VariationalParameter.divergence_fn` for + required `callable`'s parameters. + """ + self._posterior_fn = posterior_fn + self._posterior = NotSet() + self._posterior_tensor_fn = posterior_tensor_fn + self._posterior_tensor = NotSet() + self._prior_fn = prior_fn + self._prior = NotSet() + self._divergence_fn = divergence_fn + self._divergence = NotSet() + self._init_helper() + + @property + def posterior_fn(self): + """`callable` which creates `tf.distributions.Distribution`-like posterior. + + The `callable` must accept the following parameters: + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Returns: + posterior_fn: The Python `callable` specified in `__init__`. + """ + return self._posterior_fn + + @property + def posterior(self): + """`tf.distributions.Distribution`-like instance representing posterior.""" + return self._posterior + + @posterior.setter + def posterior(self, value): + """One-time setter of the `posterior` distribution.""" + if not isinstance(self._posterior, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior = value + + @property + def posterior_tensor_fn(self): + """Creates `Tensor` representing the `posterior` distribution. + + The `callable` must accept the following parameters: + posterior: `tf.distributions.Distribution`-like instance. + + Returns: + posterior_tensor_fn: The Python `callable` specified in + `__init__`. + """ + return self._posterior_tensor_fn + + @property + def posterior_tensor(self): + """`Tensor` representing the `posterior` distribution.""" + return self._posterior_tensor + + @posterior_tensor.setter + def posterior_tensor(self, value): + """One-time setter of the `posterior_tensor`.""" + if not isinstance(self._posterior_tensor, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_tensor = value + + @property + def prior_fn(self): + """`callable` which creates `tf.distributions.Distribution`-like prior. + + The `callable` must accept the following parameters: + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Returns: + prior_fn: The Python `callable` specified in `__init__`. + """ + return self._prior_fn + + @property + def prior(self): + """`tf.distributions.Distribution`-like instance representing posterior.""" + return self._prior + + @prior.setter + def prior(self, value): + """One-time setter of the `prior` distribution.""" + if not isinstance(self._prior, NotSet): + raise ValueError("Cannot override already set attribute.") + self._prior = value + + @property + def divergence_fn(self): + """`callable` which computes KL-divergence `Tensor` from posterior to prior. + + The `callable` must accept the following parameters: + posterior: `tf.distributions.Distribution`-like instance. + prior: `tf.distributions.Distribution`-like instance. + posterior_tensor: `Tensor` representing value of posterior. + + Returns: + divergence_fn: The Python `callable` specified in `__init__`. + """ + return self._divergence_fn + + @property + def divergence(self): + """`Tensor` representing KL-divergence from posterior to prior.""" + return self._divergence + + @divergence.setter + def divergence(self, value): + """One-time setter of the `divergence`.""" + if not isinstance(self._divergence, NotSet): + raise ValueError("Cannot override already set attribute.") + self._divergence = value + + def _init_helper(self): + pass + + +class VariationalKernelParameter(VariationalParameter): + """Struct-like container of variational kernel properties. + + A `VariationalKernelParameter` is intitialized with Python `callable`s which + set the value of correspondingly named members. Corresponding values have "set + once" semantics, i.e., once set to any value they are immutable. + """ + + @property + def posterior_affine(self): + """`tf.distributions.Distribution` affine transformed posterior.""" + return self._posterior_affine + + @posterior_affine.setter + def posterior_affine(self, value): + """One-time setter of `posterior_affine`.""" + if not isinstance(self._posterior_affine, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_affine = value + + @property + def posterior_affine_tensor(self): + """`Tensor` representing the `posterior_affine` distribution.""" + return self._posterior_affine_tensor + + @posterior_affine_tensor.setter + def posterior_affine_tensor(self, value): + """One-time setter of the `posterior_affine_tensor`.""" + if not isinstance(self._posterior_affine_tensor, NotSet): + raise ValueError("Cannot override already set attribute.") + self._posterior_affine_tensor = value + + def _init_helper(self): + self._posterior_affine = NotSet() + self._posterior_affine_tensor = NotSet() diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py b/tensorflow/contrib/bayesflow/python/ops/optimizers.py similarity index 77% rename from tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py rename to tensorflow/contrib/bayesflow/python/ops/optimizers.py index b8e38b6f9bf86aef42627cf127a93ce2edd42451..ee32e6b5c3d9efaeaf73436638c5eea55f2cfc70 100644 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py +++ b/tensorflow/contrib/bayesflow/python/ops/optimizers.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support for Stochastic Computation Graphs. +"""Probabilistic optimizer modules. -See the @{$python/contrib.bayesflow.stochastic_graph} guide. - -@@surrogate_loss +See ${python/contrib.bayesflow.optimizers}. """ from __future__ import absolute_import @@ -25,13 +23,12 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.stochastic_graph_impl import * +from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented - _allowed_symbols = [ - "surrogate_loss" + 'SGLDOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d36ea7a2b51aa45cdc253992a2a58634c068987 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py @@ -0,0 +1,216 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An optimizer module for stochastic gradient Langevin dynamics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +class SGLDOptimizer(optimizer.Optimizer): + """An optimizer module for stochastic gradient Langevin dynamics. + + This implements the preconditioned Stochastic Gradient Langevin Dynamics + optimizer [1]. The optimization variable is regarded as a sample from the + posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in + each dimension according to RMSProp [2]. + + Note: If a prior is included in the loss, it should be scaled by + `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches + in the data. I.e., it should be divided by the `num_pseudo_batches` term + described below. + + [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural + Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin. + ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666 + [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf + + Args: + learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the + optimizer. Must be tuned to the specific function being minimized. + preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential + decay rate of the rescaling of the preconditioner (RMSprop). (This is + "alpha" in [1]). Should be smaller than but nearly `1` to approximate + sampling from the posterior. (Default: `0.95`) + num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of + minibatches in the data set. Trades off noise and prior with the SGD + likelihood term. Note: Assumes the loss is taken as the mean over a + minibatch. Otherwise if the sum was taken, divide this number by the + batch size. (Default: `1`) + burnin: Scalar `int`-like `Tensor`. The number of iterations to collect + gradient statistics to update the preconditioner before starting to draw + noisy samples. (Default: `25`) + diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of + the preconditioner to prevent the preconditioner from degenerating. + (Default: `1e-8`) + name: Python `str` describing ops managed by this function. + (Default: `"SGLDOptimizer"`) + variable_scope: Variable scope used for calls to `tf.get_variable`. + If `None`, a new variable scope is created using name + `ops.get_default_graph().unique_name(name or default_name)`. + + Raises: + InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in + `(0,1]`. + """ + + def __init__(self, + learning_rate, + preconditioner_decay_rate=0.95, + num_pseudo_batches=1, + burnin=25, + diagonal_bias=1e-8, + name=None, + variable_scope=None): + default_name = 'SGLDOptimizer' + with ops.name_scope(name, default_name, [ + learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin, + diagonal_bias + ]): + if variable_scope is None: + var_scope_name = ops.get_default_graph().unique_name( + name or default_name) + with varscope_ops.variable_scope(var_scope_name) as scope: + self._variable_scope = scope + else: + self._variable_scope = variable_scope + + self._preconditioner_decay_rate = ops.convert_to_tensor( + preconditioner_decay_rate, name='preconditioner_decay_rate') + self._num_pseudo_batches = ops.convert_to_tensor( + num_pseudo_batches, name='num_pseudo_batches') + self._burnin = ops.convert_to_tensor(burnin, name='burnin') + self._diagonal_bias = ops.convert_to_tensor( + diagonal_bias, name='diagonal_bias') + self._learning_rate = ops.convert_to_tensor( + learning_rate, name='learning_rate') + + with varscope_ops.variable_scope(self._variable_scope): + self._counter = varscope_ops.get_variable( + 'counter', initializer=0, trainable=False) + + self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._preconditioner_decay_rate, + message='`preconditioner_decay_rate` must be non-negative'), + check_ops.assert_less_equal( + self._preconditioner_decay_rate, + 1., + message='`preconditioner_decay_rate` must be at most 1.'), + ], self._preconditioner_decay_rate) + + self._num_pseudo_batches = control_flow_ops.with_dependencies([ + check_ops.assert_greater( + self._num_pseudo_batches, + 0, + message='`num_pseudo_batches` must be greater than zero') + ], self._num_pseudo_batches) + + self._burnin = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._burnin, message='`burnin` must be non-negative'), + check_ops.assert_integer( + self._burnin, message='`burnin` must be an integer') + ], self._burnin) + + self._diagonal_bias = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._diagonal_bias, + message='`diagonal_bias` must be non-negative') + ], self._diagonal_bias) + + super(SGLDOptimizer, self).__init__(use_locking=False, + name=name or default_name) + + def _create_slots(self, var_list): + for v in var_list: + init_rms = init_ops.ones_initializer(dtype=v.dtype) + self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(), + v.dtype, 'rms', self._name) + + def _prepare(self): + # We need to put the conversion and check here because a user will likely + # want to decay the learning rate dynamically. + self._learning_rate_tensor = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._learning_rate, message='`learning_rate` must be non-negative') + ], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor')) + self._decay_tensor = ops.convert_to_tensor( + self._preconditioner_decay_rate, name='preconditioner_decay_rate') + + super(SGLDOptimizer, self)._prepare() + + def _apply_dense(self, grad, var): + rms = self.get_slot(var, 'rms') + + with ops.control_dependencies([ + self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, + var.dtype.base_dtype))]): + new_grad = self._apply_noisy_update(rms, grad) + + return training_ops.apply_gradient_descent( + var, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + new_grad, + use_locking=self._use_locking).op + + def _apply_sparse(self, grad, var): + rms = self.get_slot(var, 'rms') + + with ops.control_dependencies([ + self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor, + var.dtype.base_dtype))]): + new_grad = self._apply_noisy_update(rms, grad) + + return training_ops.apply_gradient_descent( + var, + math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype), + new_grad, + use_locking=self._use_locking).op + + @property + def variable_scope(self): + """Variable scope of all calls to `tf.get_variable`.""" + return self._variable_scope + + def _apply_noisy_update(self, mom, grad): + # Compute and apply the gradient update following + # preconditioned Langevin dynamics + stddev = array_ops.where( + array_ops.squeeze(self._counter > self._burnin), + math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype), + array_ops.zeros([], grad.dtype)) + + preconditioner = math_ops.rsqrt( + mom + math_ops.cast(self._diagonal_bias, grad.dtype)) + return ( + 0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches, + grad.dtype) + + random_ops.random_normal(array_ops.shape(grad), 1.0, dtype=grad.dtype) * + stddev * math_ops.sqrt(preconditioner)) + + def _update_momentum(self, mom, grad, decay): + # Keep an exponentially weighted moving average of squared gradients. + # Not thread safe + return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom)) diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py deleted file mode 100644 index 695310837e0f6a58842f45c28608f12fbe162c6e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Stochastic gradient estimators. - -These functions are meant to be used in conjunction with `StochasticTensor` -(`loss_fn` parameter) and `surrogate_loss`. - -See Gradient Estimation Using Stochastic Computation Graphs -(http://arxiv.org/abs/1506.05254) by Schulman et al., eq. 1 and section 4, for -mathematical details. - -## Score function estimator - -The score function is an unbiased estimator of the gradient of `E_p(x)[f(x)]`, -where `f(x)` can be considered to be a "loss" term. It is computed as -`E_p(x)[f(x) grad(log p(x))]`. A constant `b`, referred to here as the -"baseline", can be subtracted from `f(x)` without affecting the expectation. The -term `(f(x) - b)` is referred to here as the "advantage". - -Note that the methods defined in this module actually compute the integrand of -the score function, such that when taking the gradient, the true score function -is computed. - -@@score_function -@@get_score_function_with_baseline -@@get_score_function_with_constant_baseline -@@get_score_function_with_advantage - -## Baseline functions - -Baselines reduce the variance of Monte Carlo estimate of an expectation. The -baseline for a stochastic node can be a function of all non-influenced nodes -(see section 4 of Schulman et al., linked above). Baselines are also known as -"control variates." - -In the context of a MC estimate of `E_p(x)[f(x) - b]`, baseline functions have -the signature `(st, fx) => Tensor`, where `st` is a `StochasticTensor` backed by -the distribution `p(x)` and `fx` is the influenced loss. - -@@get_mean_baseline - -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import training -from tensorflow.python.util.all_util import make_all - - -def score_function(stochastic_tensor, value, loss, baseline=None, - name="ScoreFunction"): - """Score function estimator. - - Computes the integrand of the score function with a baseline: - `p.log_prob(value) * (loss - baseline)`. - - It will add a `stop_gradient` to the advantage `(loss - baseline)`. - - Args: - stochastic_tensor: `StochasticTensor` p(x). - value: `Tensor` x. Samples from p(x). - loss: `Tensor`. - baseline: `Tensor` broadcastable to `loss`. - name: name to prepend ops with. - - Returns: - `Tensor` `p.log_prob(x) * (loss - b)`. Taking the gradient yields the score - function estimator. - """ - with ops.name_scope(name, values=[value, loss, baseline]): - value = ops.convert_to_tensor(value) - loss = ops.convert_to_tensor(loss) - if baseline is not None: - baseline = ops.convert_to_tensor(baseline) - advantage = loss - baseline - else: - advantage = loss - - advantage = array_ops.stop_gradient(advantage) - return stochastic_tensor.distribution.log_prob(value) * advantage - - -def get_score_function_with_advantage(advantage_fn=None, - name="ScoreFunctionWithAdvantage"): - """Score function estimator with advantage function. - - Args: - advantage_fn: callable that takes the `StochasticTensor` and the - downstream `loss` and returns a `Tensor` advantage - (e.g. `loss - baseline`). - name: name to prepend ops with. - - Returns: - Callable score function estimator that takes the `StochasticTensor`, the - sampled `value`, and the downstream `loss`, and uses the provided advantage. - """ - - def score_function_with_advantage(stochastic_tensor, value, loss): - with ops.name_scope(name, values=[value, loss]): - advantage = advantage_fn(stochastic_tensor, loss) - advantage = array_ops.stop_gradient(advantage) - return stochastic_tensor.distribution.log_prob(value) * advantage - - return score_function_with_advantage - - -def get_score_function_with_constant_baseline(baseline, name="ScoreFunction"): - """Score function estimator with constant baseline. - - Args: - baseline: `Tensor` to be subtracted from loss. - name: name to prepend ops with. - - Returns: - Callable score function estimator that takes the `StochasticTensor`, the - sampled `value`, and the downstream `loss`, and subtracts the provided - `baseline` from the `loss`. - """ - - def score_function_with_constant_baseline(stochastic_tensor, value, loss): - return score_function(stochastic_tensor, value, loss, baseline, name) - - return score_function_with_constant_baseline - - -def get_score_function_with_baseline(baseline_fn=None, name="ScoreFunction"): - """Score function estimator with baseline function. - - Args: - baseline_fn: callable that takes the `StochasticTensor` and the downstream - `loss` and returns a `Tensor` baseline to be subtracted from the `loss`. - If None, defaults to `get_mean_baseline`, which is an EMA of the loss. - name: name to prepend ops with. - - Returns: - Callable score function estimator that takes the `StochasticTensor`, the - sampled `value`, and the downstream `loss`, and subtracts the provided - `baseline` from the `loss`. - """ - if baseline_fn is None: - baseline_fn = get_mean_baseline() - - def score_function_with_baseline(stochastic_tensor, value, loss): - with ops.name_scope(name): - b = baseline_fn(stochastic_tensor, loss) - return score_function(stochastic_tensor, value, loss, b) - - return score_function_with_baseline - - -def get_mean_baseline(ema_decay=0.99, name=None): - """ExponentialMovingAverage baseline. - - Args: - ema_decay: decay rate for the ExponentialMovingAverage. - name: name for variable scope of the ExponentialMovingAverage. - - Returns: - Callable baseline function that takes the `StochasticTensor` (unused) and - the downstream `loss`, and returns an EMA of the loss. - """ - - def mean_baseline(_, loss): - with vs.variable_scope(name, default_name="MeanBaseline"): - reduced_loss = math_ops.reduce_mean(loss) - - ema = training.ExponentialMovingAverage(decay=ema_decay, zero_debias=True) - update_op = ema.apply([reduced_loss]) - - with ops.control_dependencies([update_op]): - # Using `identity` causes an op to be added in this context, which - # triggers the update. Removing the `identity` means nothing is updated. - baseline = array_ops.identity(ema.average(reduced_loss)) - - return baseline - - return mean_baseline - - -def get_vimco_advantage_fn(have_log_loss=False): - """VIMCO (Variational Inference for Monte Carlo Objectives) baseline. - - Implements VIMCO baseline from the article of the same name: - - https://arxiv.org/pdf/1602.06725v2.pdf - - Given a `loss` tensor (containing non-negative probabilities or ratios), - calculates the advantage VIMCO advantage via Eq. 9 of the above paper. - - The tensor `loss` should be shaped `[n, ...]`, with rank at least 1. Here, - the first axis is considered the single sampling dimension and `n` must - be at least 2. Specifically, the `StochasticTensor` is assumed to have - used the `SampleValue(n)` value type with `n > 1`. - - Args: - have_log_loss: Python `Boolean`. If `True`, the loss is assumed to be the - log loss. If `False` (the default), it is assumed to be a nonnegative - probability or probability ratio. - - Returns: - Callable baseline function that takes the `StochasticTensor` (unused) and - the downstream `loss`, and returns the VIMCO baseline for the loss. - """ - def vimco_advantage_fn(_, loss, name=None): - """Internal VIMCO function. - - Args: - _: ignored `StochasticTensor`. - loss: The loss `Tensor`. - name: Python string, the name scope to use. - - Returns: - The advantage `Tensor`. - """ - with ops.name_scope(name, "VIMCOAdvantage", values=[loss]): - loss = ops.convert_to_tensor(loss) - loss_shape = loss.get_shape() - loss_num_elements = loss_shape[0].value - n = math_ops.cast( - loss_num_elements or array_ops.shape(loss)[0], dtype=loss.dtype) - - if have_log_loss: - log_loss = loss - else: - log_loss = math_ops.log(loss) - - # Calculate L_hat, Eq. (4) -- stably - log_mean = math_ops.reduce_logsumexp(log_loss, [0]) - math_ops.log(n) - - # expand_dims: Expand shape [a, b, c] to [a, 1, b, c] - log_loss_expanded = array_ops.expand_dims(log_loss, [1]) - - # divide: log_loss_sub with shape [a, a, b, c], where - # - # log_loss_sub[i] = log_loss - log_loss[i] - # - # = [ log_loss[j] - log_loss[i] for rows j = 0 ... i - 1 ] - # [ zeros ] - # [ log_loss[j] - log_loss[i] for rows j = i + 1 ... a - 1 ] - # - log_loss_sub = log_loss - log_loss_expanded - - # reduce_sum: Sums each row across all the sub[i]'s; result is: - # reduce_sum[j] = (n - 1) * log_loss[j] - (sum_{i != j} loss[i]) - # divide by (n - 1) to get: - # geometric_reduction[j] = - # log_loss[j] - (sum_{i != j} log_loss[i]) / (n - 1) - geometric_reduction = math_ops.reduce_sum(log_loss_sub, [0]) / (n - 1) - - # subtract this from the original log_loss to get the baseline: - # geometric_mean[j] = exp((sum_{i != j} log_loss[i]) / (n - 1)) - log_geometric_mean = log_loss - geometric_reduction - - ## Equation (9) - - # Calculate sum_{i != j} loss[i] -- via exp(reduce_logsumexp(.)) - # reduce_logsumexp: log-sum-exp each row across all the - # -sub[i]'s, result is: - # - # exp(reduce_logsumexp[j]) = - # 1 + sum_{i != j} exp(log_loss[i] - log_loss[j]) - log_local_learning_reduction = math_ops.reduce_logsumexp( - -log_loss_sub, [0]) - - # convert local_learning_reduction to the sum-exp of the log-sum-exp - # (local_learning_reduction[j] - 1) * exp(log_loss[j]) - # = sum_{i != j} exp(log_loss[i]) - local_learning_log_sum = ( - _logexpm1(log_local_learning_reduction) + log_loss) - - # Add (logaddexp) the local learning signals (Eq. 9) - local_learning_signal = ( - math_ops.reduce_logsumexp( - array_ops.stack((local_learning_log_sum, log_geometric_mean)), - [0]) - - math_ops.log(n)) - - advantage = log_mean - local_learning_signal - - return advantage - - return vimco_advantage_fn - - -def _logexpm1(x): - """Stably calculate log(exp(x)-1).""" - with ops.name_scope("logsumexp1"): - eps = np.finfo(x.dtype.as_numpy_dtype).eps - # Choose a small offset that makes gradient calculations stable for - # float16, float32, and float64. - safe_log = lambda y: math_ops.log(y + eps / 1e8) # For gradient stability - return array_ops.where( - math_ops.abs(x) < eps, - safe_log(x) + x/2 + x*x/24, # small x approximation to log(expm1(x)) - safe_log(math_ops.exp(x) - 1)) - - -__all__ = make_all(__name__) diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph_impl.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph_impl.py deleted file mode 100644 index b2338bca8c94e0c7c44182f3f6bba7d7e79595e1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph_impl.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Classes and helper functions for Stochastic Computation Graphs. - -## Stochastic Computation Graph Helper Functions - -@@surrogate_loss -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import tf_logging as logging - - -def _upstream_stochastic_nodes(tensors): - """Map tensors to the stochastic tensors upstream of them. - - Args: - tensors: a list of Tensors. - - Returns: - A dict that maps the tensors passed in to the `StochasticTensor` objects - upstream of them. - """ - reverse_map = _stochastic_dependencies_map(tensors) - upstream = collections.defaultdict(set) - for st, ts in reverse_map.items(): - for t in ts: - upstream[t].add(st) - return upstream - - -def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None): - """Map stochastic tensors to the fixed losses that depend on them. - - Args: - fixed_losses: a list of `Tensor`s. - stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses. - If `None`, all `StochasticTensor`s in the graph will be used. - - Returns: - A dict `dependencies` that maps `StochasticTensor` objects to subsets of - `fixed_losses`. - - If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there - is a direct path from `st.value()` to `loss` in the graph. - """ - stoch_value_collection = stochastic_tensors or ops.get_collection( - stochastic_tensor_impl.STOCHASTIC_TENSOR_COLLECTION) - - if not stoch_value_collection: - return {} - - stoch_value_map = dict( - (node.value(), node) for node in stoch_value_collection) - - # Step backwards through the graph to see which surrogate losses correspond - # to which fixed_losses. - # - # TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the - # same frame. - stoch_dependencies_map = collections.defaultdict(set) - for loss in fixed_losses: - boundary = set([loss]) - while boundary: - edge = boundary.pop() - edge_stoch_node = stoch_value_map.get(edge, None) - if edge_stoch_node: - stoch_dependencies_map[edge_stoch_node].add(loss) - boundary.update(edge.op.inputs) - - return stoch_dependencies_map - - -def surrogate_loss(sample_losses, - stochastic_tensors=None, - name="SurrogateLoss"): - """Surrogate loss for stochastic graphs. - - This function will call `loss_fn` on each `StochasticTensor` - upstream of `sample_losses`, passing the losses that it influenced. - - Note that currently `surrogate_loss` does not work with `StochasticTensor`s - instantiated in `while_loop`s or other control structures. - - Args: - sample_losses: a list or tuple of final losses. Each loss should be per - example in the batch (and possibly per sample); that is, it should have - dimensionality of 1 or greater. All losses should have the same shape. - stochastic_tensors: a list of `StochasticTensor`s to add loss terms for. - If None, defaults to all `StochasticTensor`s in the graph upstream of - the `Tensor`s in `sample_losses`. - name: the name with which to prepend created ops. - - Returns: - `Tensor` loss, which is the sum of `sample_losses` and the - `loss_fn`s returned by the `StochasticTensor`s. - - Raises: - TypeError: if `sample_losses` is not a list or tuple, or if its elements - are not `Tensor`s. - ValueError: if any loss in `sample_losses` does not have dimensionality 1 - or greater. - """ - with ops.name_scope(name, values=sample_losses): - if not isinstance(sample_losses, (list, tuple)): - raise TypeError("sample_losses must be a list or tuple") - for loss in sample_losses: - if not isinstance(loss, ops.Tensor): - raise TypeError("loss is not a Tensor: %s" % loss) - ndims = loss.get_shape().ndims - if not (ndims is not None and ndims >= 1): - raise ValueError("loss must have dimensionality 1 or greater: %s" % - loss) - - stoch_dependencies_map = _stochastic_dependencies_map( - sample_losses, stochastic_tensors=stochastic_tensors) - if not stoch_dependencies_map: - logging.warn( - "No collection of Stochastic Tensors found for current graph.") - return math_ops.add_n(sample_losses) - - # Iterate through all of the stochastic dependencies, adding - # surrogate terms where necessary. - sample_losses = [ops.convert_to_tensor(loss) for loss in sample_losses] - loss_terms = sample_losses - for (stoch_node, dependent_losses) in stoch_dependencies_map.items(): - dependent_losses = list(dependent_losses) - - logging.info("Losses influenced by StochasticTensor %s: [%s]", - stoch_node.name, ", ".join( - [loss.name for loss in dependent_losses])) - - # Sum up the downstream losses for this ST - influenced_loss = _add_n_or_sum(dependent_losses) - - # Compute surrogate loss term - loss_term = stoch_node.loss(array_ops.stop_gradient(influenced_loss)) - if loss_term is not None: - loss_terms.append(loss_term) - - return _add_n_or_sum(loss_terms) - - -def _add_n_or_sum(terms): - # add_n works for Tensors of the same dtype and shape - shape = terms[0].get_shape() - dtype = terms[0].dtype - - if all(term.get_shape().is_fully_defined() and - term.get_shape().is_compatible_with(shape) and term.dtype == dtype - for term in terms): - return math_ops.add_n(terms) - else: - return sum(terms) diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor_impl.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor_impl.py deleted file mode 100644 index ce5fdd98c69ca6b3482bfafa8859accdf8a78749..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor_impl.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Classes and helper functions for creating Stochastic Tensors. - -`StochasticTensor` objects wrap `Distribution` objects. Their -values may be samples from the underlying distribution, or the distribution -mean (as governed by `value_type`). These objects provide a `loss` -method for use when sampling from a non-reparameterized distribution. -The `loss`method is used in conjunction with `stochastic_graph.surrogate_loss` -to produce a single differentiable loss in stochastic graphs having -both continuous and discrete stochastic nodes. - -## Stochastic Tensor Classes - -@@BaseStochasticTensor -@@StochasticTensor - -## Stochastic Tensor Value Types - -@@MeanValue -@@SampleValue - -@@value_type -@@get_current_value_type -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import collections -import contextlib -import threading - -import six - -from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators as sge -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions import distribution - -STOCHASTIC_TENSOR_COLLECTION = "_stochastic_tensor_collection_" - - -@six.add_metaclass(abc.ABCMeta) -class BaseStochasticTensor(object): - """Base Class for Tensor-like objects that emit stochastic values.""" - - def __init__(self): - # Add self to this graph's Stochsatic Tensor collection for - # purposes of later performing correct surrogate loss calculation. - ops.add_to_collection(STOCHASTIC_TENSOR_COLLECTION, self) - - @abc.abstractproperty - def name(self): - pass - - @abc.abstractproperty - def dtype(self): - pass - - @abc.abstractproperty - def graph(self): - pass - - @abc.abstractmethod - def value(self, name=None): - pass - - @abc.abstractmethod - def loss(self, sample_loss): - """Returns the term to add to the surrogate loss. - - This method is called by `surrogate_loss`. The input `sample_loss` should - have already had `stop_gradient` applied to it. This is because the - surrogate_loss usually provides a Monte Carlo sample term of the form - `differentiable_surrogate * sample_loss` where `sample_loss` is considered - constant with respect to the input for purposes of the gradient. - - Args: - sample_loss: `Tensor`, sample loss downstream of this `StochasticTensor`. - - Returns: - Either `None` or a `Tensor`. - """ - raise NotImplementedError("surrogate_loss not implemented") - - @staticmethod - def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False): - _ = name - if dtype and not dtype.is_compatible_with(v.dtype): - raise ValueError( - "Incompatible type conversion requested to type '%s' for variable " - "of type '%s'" % (dtype.name, v.dtype.name)) - if as_ref: - raise ValueError("%s: Ref type is not supported." % v) - return v.value() - - -# pylint: disable=protected-access -ops.register_tensor_conversion_function( - BaseStochasticTensor, BaseStochasticTensor._tensor_conversion_function) - -# pylint: enable=protected-access - - -class _StochasticValueType(object): - """Interface for the ValueType classes. - - This is the base class for MeanValue, SampleValue, and their descendants. - """ - - def pushed_above(self, unused_value_type): - pass - - def popped_above(self, unused_value_type): - pass - - def declare_inputs(self, unused_stochastic_tensor, unused_inputs_dict): - pass - - @abc.abstractproperty - def stop_gradient(self): - """Whether the value should be wrapped in stop_gradient. - - StochasticTensors must respect this property. - """ - pass - - -class MeanValue(_StochasticValueType): - - def __init__(self, stop_gradient=False): - self._stop_gradient = stop_gradient - - @property - def stop_gradient(self): - return self._stop_gradient - - -class SampleValue(_StochasticValueType): - """Draw samples, possibly adding new outer dimensions along the way. - - This ValueType draws samples from StochasticTensors run within its - context, increasing the rank according to the requested shape. - - Examples: - - ```python - mu = tf.zeros((2,3)) - sigma = tf.ones((2, 3)) - with sg.value_type(sg.SampleValue()): - st = sg.StochasticTensor( - tf.contrib.distributions.Normal, mu=mu, sigma=sigma) - # draws 1 sample and does not reshape - assertEqual(st.value().get_shape(), (2, 3)) - ``` - - ```python - mu = tf.zeros((2,3)) - sigma = tf.ones((2, 3)) - with sg.value_type(sg.SampleValue(4)): - st = sg.StochasticTensor( - tf.contrib.distributions.Normal, mu=mu, sigma=sigma) - # draws 4 samples each with shape (2, 3) and concatenates - assertEqual(st.value().get_shape(), (4, 2, 3)) - ``` - """ - - def __init__(self, shape=(), stop_gradient=False): - """Sample according to shape. - - For the given StochasticTensor `st` using this value type, - the shape of `st.value()` will match that of - `st.distribution.sample(shape)`. - - Args: - shape: A shape tuple or int32 tensor. The sample shape. - Default is a scalar: take one sample and do not change the size. - stop_gradient: If `True`, StochasticTensors' values are wrapped in - `stop_gradient`, to avoid backpropagation through. - """ - self._shape = shape - self._stop_gradient = stop_gradient - - @property - def shape(self): - return self._shape - - @property - def stop_gradient(self): - return self._stop_gradient - - -# Keeps track of how a StochasticTensor's value should be accessed. -# Used by value_type and get_current_value_type below. -_STOCHASTIC_VALUE_STACK = collections.defaultdict(list) - - -@contextlib.contextmanager -def value_type(dist_value_type): - """Creates a value type context for any StochasticTensor created within. - - Typical usage: - - ``` - with sg.value_type(sg.MeanValue(stop_gradients=True)): - st = sg.StochasticTensor(tf.contrib.distributions.Normal, mu=mu, - sigma=sigma) - ``` - - In the example above, `st.value()` (or equivalently, `tf.identity(st)`) will - be the mean value of the Normal distribution, i.e., `mu` (possibly - broadcasted to the shape of `sigma`). Furthermore, because the `MeanValue` - was marked with `stop_gradients=True`, this value will have been wrapped - in a `stop_gradients` call to disable any possible backpropagation. - - Args: - dist_value_type: An instance of `MeanValue`, `SampleValue`, or - any other stochastic value type. - - Yields: - A context for `StochasticTensor` objects that controls the - value created when they are initialized. - - Raises: - TypeError: if `dist_value_type` is not an instance of a stochastic value - type. - """ - if not isinstance(dist_value_type, _StochasticValueType): - raise TypeError("dist_value_type must be a Distribution Value Type") - thread_id = threading.current_thread().ident - stack = _STOCHASTIC_VALUE_STACK[thread_id] - if stack: - stack[-1].pushed_above(dist_value_type) - stack.append(dist_value_type) - yield - stack.pop() - if stack: - stack[-1].popped_above(dist_value_type) - - -class NoValueTypeSetError(ValueError): - pass - - -def get_current_value_type(): - thread_id = threading.current_thread().ident - if not _STOCHASTIC_VALUE_STACK[thread_id]: - raise NoValueTypeSetError( - "No value type currently set for this thread (%s). Did you forget to " - "wrap 'with stochastic_graph.value_type(...)'?" % thread_id) - return _STOCHASTIC_VALUE_STACK[thread_id][-1] - - -class StochasticTensor(BaseStochasticTensor): - """StochasticTensor is a BaseStochasticTensor backed by a distribution.""" - - def __init__(self, - dist, - name="StochasticTensor", - dist_value_type=None, - loss_fn=sge.score_function): - """Construct a `StochasticTensor`. - - `StochasticTensor` is backed by the `dist` distribution and its `value` - method will return the same value each time it is called. What `value` is - returned is controlled by the `dist_value_type` (defaults to - `SampleValue`). - - Some distributions' sample functions are not differentiable (e.g. a sample - from a discrete distribution like a Bernoulli) and so to differentiate - wrt parameters upstream of the sample requires a gradient estimator like - the score function estimator. This is accomplished by passing a - differentiable `loss_fn` to the `StochasticTensor`, which - defaults to a function whose derivative is the score function estimator. - Calling `stochastic_graph.surrogate_loss(final_losses)` will call - `loss()` on every `StochasticTensor` upstream of final losses. - - `loss()` will return None for `StochasticTensor`s backed by - reparameterized distributions; it will also return None if the value type is - `MeanValueType` or if `loss_fn=None`. - - Args: - dist: an instance of `Distribution`. - name: a name for this `StochasticTensor` and its ops. - dist_value_type: a `_StochasticValueType`, which will determine what the - `value` of this `StochasticTensor` will be. If not provided, the - value type set with the `value_type` context manager will be used. - loss_fn: callable that takes - `(st, st.value(), influenced_loss)`, where - `st` is this `StochasticTensor`, and returns a `Tensor` loss. By - default, `loss_fn` is the `score_function`, or more precisely, the - integral of the score function, such that when the gradient is taken, - the score function results. See the `stochastic_gradient_estimators` - module for additional loss functions and baselines. - - Raises: - TypeError: if `dist` is not an instance of `Distribution`. - TypeError: if `loss_fn` is not `callable`. - """ - if not isinstance(dist, distribution.Distribution): - raise TypeError("dist must be an instance of Distribution") - if dist_value_type is None: - try: - self._value_type = get_current_value_type() - except NoValueTypeSetError: - self._value_type = SampleValue() - else: - # We want to enforce a value type here, but use the value_type() - # context manager to enforce some error checking. - with value_type(dist_value_type): - self._value_type = get_current_value_type() - - if loss_fn is not None and not callable(loss_fn): - raise TypeError("loss_fn must be callable") - self._loss_fn = loss_fn - - with ops.name_scope(name) as scope: - self._name = scope - self._dist = dist - self._value = self._create_value() - - super(StochasticTensor, self).__init__() - - @property - def value_type(self): - return self._value_type - - @property - def distribution(self): - return self._dist - - def _create_value(self): - """Create the value Tensor based on the value type, store as self._value.""" - - if isinstance(self._value_type, MeanValue): - value_tensor = self._dist.mean() - elif isinstance(self._value_type, SampleValue): - value_tensor = self._dist.sample(self._value_type.shape) - else: - raise TypeError("Unrecognized Distribution Value Type: %s", - self._value_type) - - if self._value_type.stop_gradient: - # stop_gradient is being enforced by the value type - return array_ops.stop_gradient(value_tensor) - - if isinstance(self._value_type, MeanValue): - return value_tensor # Using pathwise-derivative for this one. - if self._dist.reparameterization_type == distribution.FULLY_REPARAMETERIZED: - return value_tensor # Using pathwise-derivative for this one. - else: - # Will have to perform some variant of score function - # estimation. Call stop_gradient on the sampler just in case we - # may accidentally leak some gradient from it. - return array_ops.stop_gradient(value_tensor) - - @property - def name(self): - return self._name - - @property - def graph(self): - return self._value.graph - - @property - def dtype(self): - return self._dist.dtype - - def entropy(self, name="entropy"): - return self._dist.entropy(name=name) - - def mean(self, name="mean"): - return self._dist.mean(name=name) - - def value(self, name="value"): - return self._value - - def loss(self, final_loss, name="Loss"): - # Return a loss based on final_loss and the distribution. Returns - # None if pathwise derivatives are supported, if the loss_fn - # was explicitly set to None, or if the value type is MeanValue. - if self._loss_fn is None: - return None - - if (self._dist.reparameterization_type == distribution.FULLY_REPARAMETERIZED - and not self._value_type.stop_gradient): - # Can perform pathwise-derivative on this one; no additional loss needed. - return None - - with ops.name_scope(self.name, values=[final_loss]): - with ops.name_scope(name): - if (self._value_type.stop_gradient or - isinstance(self._value_type, SampleValue)): - return self._loss_fn(self, self._value, final_loss) - elif isinstance(self._value_type, MeanValue): - return None # MeanValue generally provides its own gradient - else: - raise TypeError("Unrecognized Distribution Value Type: %s", - self._value_type) - - -class ObservedStochasticTensor(StochasticTensor): - """A StochasticTensor with an observed value.""" - - # pylint: disable=super-init-not-called - def __init__(self, dist, value, name=None): - """Construct an `ObservedStochasticTensor`. - - `ObservedStochasticTensor` is backed by distribution `dist` and uses the - provided value instead of using the current value type to draw a value from - the distribution. The provided value argument must be appropriately shaped - to have come from the distribution. - - Args: - dist: an instance of `Distribution`. - value: a Tensor containing the observed value - name: a name for this `ObservedStochasticTensor` and its ops. - - Raises: - TypeError: if `dist` is not an instance of `Distribution`. - ValueError: if `value` is not compatible with the distribution. - """ - if not isinstance(dist, distribution.Distribution): - raise TypeError("dist must be an instance of Distribution") - with ops.name_scope(name, "ObservedStochasticTensor", [value]) as scope: - self._name = scope - self._dist = dist - dist_shape = self._dist.batch_shape.concatenate( - self._dist.event_shape) - value = ops.convert_to_tensor(value) - value_shape = value.get_shape() - - if not value_shape.is_compatible_with(dist_shape): - if value_shape.ndims < dist_shape.ndims: - raise ValueError( - "Rank of observed value (%d) must be >= rank of a sample from the" - " distribution (%d)." % (value_shape.ndims, dist_shape.ndims)) - sample_shape = value_shape[(value_shape.ndims - dist_shape.ndims):] - if not sample_shape.is_compatible_with(dist_shape): - raise ValueError( - "Shape of observed value %s is incompatible with the shape of a " - "sample from the distribution %s." % (value_shape, dist_shape)) - if value.dtype != self._dist.dtype: - raise ValueError("Type of observed value (%s) does not match type of " - "distribution (%s)." % (value.dtype, self._dist.dtype)) - self._value = array_ops.identity(value) - # pylint: disable=non-parent-init-called - BaseStochasticTensor.__init__(self) - - def loss(self, final_loss, name=None): - return None - - -__all__ = [ - "BaseStochasticTensor", - "StochasticTensor", - "ObservedStochasticTensor", - "MeanValue", - "SampleValue", - "value_type", - "get_current_value_type", -] diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_variables.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_variables.py deleted file mode 100644 index e16dbec11a188d42615c4e63d9f93925a6df30a3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_variables.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Custom `get_variable` for stochastic variables. - -@@get_stochastic_variable -@@make_stochastic_variable_getter -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor as st -from tensorflow.contrib.bayesflow.python.ops import variational_inference as vi - - -def get_stochastic_variable(getter, - name, - shape=None, - dist_cls=None, - dist_kwargs=None, - param_initializers=None, - prior=None, - **kwargs): - """Custom variable getter for stochastic variables. - - `get_stochastic_variable` will create variables backing the parameters of a - distribution, defined by `dist_cls`, and return a `StochasticTensor` which - represents a sample from the backing distribution. - - Meant to be passed as the `custom_getter` to a `variable_scope`. Use - `make_stochastic_variable_getter` to partially apply distribution-related - args. - - Usage: - - ```python - - sv = tf.contrib.bayesflow.stochastic_variables - dist = tf.contrib.distributions - - with tf.variable_scope('my_scope', - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusSigma - param_initializers={ - "sigma": lambda shape, dtype, pi: ( - tf.constant(0.5, dtype=dtype, shape=shape)) - })): - v = tf.get_variable('my_var', (10, 20)) - ``` - - `v` is a `StochasticTensor`, which is a sample from a backing - `NormalWithSoftplusSigma` distribution. Underneath, 2 variables have been - created: `my_var_mu` and `my_var_sigma`. `my_var_sigma` has been appropriately - constrained to be positive by the `NormalWithSoftplusSigma` constructor, and - initialized to a value of 0.5, which results in a sigma of ~1 after the - softplus. The sample will have shape `(10, 20)`. - - Args: - getter: original variable getter. - name: prefix for variable(s) backing distribution parameters. - shape: shape of the sample from the distribution (i.e. shape of the - returned `StochasticTensor`). - dist_cls: subclass of `Distribution` that implements `param_shapes`. Should - accept unconstrained parameters (e.g. `NormalWithSoftplusSigma` accepts - real-valued `sigma` and constrains it to be positive with `softplus`). - dist_kwargs: `dict` of kwargs to be forwarded to `dist_cls`. - param_initializers: `dict` from parameter name to initializer (see - `get_variable` for initializer docs). Will override `initializer` in - `kwargs`. `param_initializers` may contain initializers for only some of - the parameters. Those parameters that do not contain entries will be - initialized by `kwargs['initializer']`, if provided; otherwise, the - default initialization of `getter` will be used. - prior: instance of `Distribution` or a callable - `(TensorShape, dtype) => Distribution`. If provided, will be registered - as the prior for the `StochasticTensor` using - `variational_inference.register_prior`. - **kwargs: kwargs forwarded to `getter`. - - Returns: - `StochasticTensor`, which represents a sample from the backing distribution. - """ - param_initializers = param_initializers or {} - param_shapes = {} - - if shape is not None: - param_shapes = dist_cls.param_static_shapes(shape) - - param_names = set(list(param_shapes.keys()) + list(param_initializers.keys())) - params = {} - for param_name in param_names: - # For each parameter, its param_initializer is used, if provided. Otherwise, - # kwargs['initializer'] is used. If neither were provided, the default - # variable initialization in getter will be used (i.e. getter will be passed - # initializer=None. - original_initializer = kwargs.pop('initializer', None) - param_initializer = param_initializers.get(param_name, None) - if param_initializer is None: - param_initializer = original_initializer - - if callable(param_initializer) or param_initializer is None: - param_shape = param_shapes.get(param_name, None) - else: - param_shape = None - - params[param_name] = getter( - name + '_' + param_name, - shape=param_shape, - initializer=param_initializer, - **kwargs) - - dist_kwargs = dist_kwargs or {} - dist_kwargs.update(params) - sample = st.StochasticTensor(dist_cls(**dist_kwargs)) - - if prior is not None: - if callable(prior): - sample_value = sample.value() - sample_value.get_shape().assert_is_fully_defined() - prior = prior(sample_value.get_shape(), sample_value.dtype) - - vi.register_prior(sample, prior) - - return sample - - -def make_stochastic_variable_getter(dist_cls, - dist_kwargs=None, - param_initializers=None, - prior=None): - """`get_stochastic_variable` with args partially applied.""" - return functools.partial( - get_stochastic_variable, - dist_cls=dist_cls, - dist_kwargs=dist_kwargs, - param_initializers=param_initializers, - prior=prior) diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py b/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py deleted file mode 100644 index 8d932a7c340e21da012d4ab93883735b13e01175..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Variational inference. - -See the ${@python/contrib.bayesflow.variational_inference} guide. - -@@elbo -@@elbo_with_log_joint -@@ELBOForms -@@register_prior -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import stochastic_graph_impl as sg -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl as st -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.platform import tf_logging as logging - -VI_PRIORS = "__vi_priors__" - - -def register_prior(variational, prior): - """Associate a variational `StochasticTensor` with a `Distribution` prior. - - This is a helper function used in conjunction with `elbo` that allows users - to specify the mapping between variational distributions and their priors - without having to pass in `variational_with_prior` explicitly. - - Args: - variational: `StochasticTensor` q(Z). Approximating distribution. - prior: `Distribution` p(Z). Prior distribution. - - Returns: - None - - Raises: - ValueError: if variational is not a `StochasticTensor` or `prior` is not - a `Distribution`. - """ - if not isinstance(variational, st.StochasticTensor): - raise TypeError("variational must be a StochasticTensor") - if not isinstance(prior, distribution.Distribution): - raise TypeError("prior must be a Distribution") - ops.add_to_collection(VI_PRIORS, (variational, prior)) - - -class _ELBOForm(object): - pass - - -class ELBOForms(object): - """Constants to control the `elbo` calculation. - - `analytic_kl` uses the analytic KL divergence between the - variational distribution(s) and the prior(s). - - `analytic_entropy` uses the analytic entropy of the variational - distribution(s). - - `sample` uses the sample KL or the sample entropy is the joint is provided. - - See `elbo` for what is used with `default`. - """ - default, analytic_kl, analytic_entropy, sample = (_ELBOForm() - for _ in range(4)) - - @staticmethod - def check_form(form): - if form not in { - ELBOForms.default, ELBOForms.analytic_kl, ELBOForms.analytic_entropy, - ELBOForms.sample - }: - raise TypeError("form must be an ELBOForms constant") - - -def elbo(log_likelihood, - variational_with_prior=None, - keep_batch_dim=True, - form=None, - name="ELBO"): - r"""Evidence Lower BOund. `log p(x) >= ELBO`. - - Optimization objective for inference of hidden variables by variational - inference. - - This function is meant to be used in conjunction with `StochasticTensor`. - The user should build out the inference network, using `StochasticTensor`s - as latent variables, and the generative network. `elbo` at minimum needs - `p(x|Z)` and assumes that all `StochasticTensor`s upstream of `p(x|Z)` are - the variational distributions. Use `register_prior` to register `Distribution` - priors for each `StochasticTensor`. Alternatively, pass in - `variational_with_prior` specifying all variational distributions and their - priors. - - Mathematical details: - - ``` - log p(x) = log \int p(x, Z) dZ - = log \int \frac {q(Z)p(x, Z)}{q(Z)} dZ - = log E_q[\frac {p(x, Z)}{q(Z)}] - >= E_q[log \frac {p(x, Z)}{q(Z)}] = L[q; p, x] # ELBO - - L[q; p, x] = E_q[log p(x|Z)p(Z)] - E_q[log q(Z)] - = E_q[log p(x|Z)p(Z)] + H[q] (1) - = E_q[log p(x|Z)] - KL(q || p) (2) - - H - Entropy - KL - Kullback-Leibler divergence - ``` - - See section 2.2 of Stochastic Variational Inference by Hoffman et al. for - more, including the ELBO's equivalence to minimizing `KL(q(Z)||p(Z|x))` - in the fully Bayesian setting. https://arxiv.org/pdf/1206.7051.pdf. - - `form` specifies which form of the ELBO is used. `form=ELBOForms.default` - tries, in order of preference: analytic KL, analytic entropy, sampling. - - Multiple entries in the `variational_with_prior` dict implies a factorization. - e.g. `q(Z) = q(z1)q(z2)q(z3)`. - - Args: - log_likelihood: `Tensor` log p(x|Z). - variational_with_prior: dict from `StochasticTensor` q(Z) to - `Distribution` p(Z). If `None`, defaults to all `StochasticTensor` - objects upstream of `log_likelihood` with priors registered with - `register_prior`. - keep_batch_dim: bool. Whether to keep the batch dimension when summing - entropy/KL term. When the sample is per data point, this should be True; - otherwise (e.g. in a Bayesian NN), this should be False. - form: ELBOForms constant. Controls how the ELBO is computed. Defaults to - ELBOForms.default. - name: name to prefix ops with. - - Returns: - `Tensor` ELBO of the same type and shape as `log_likelihood`. - - Raises: - TypeError: if variationals in `variational_with_prior` are not - `StochasticTensor`s or if priors are not `Distribution`s. - TypeError: if form is not a valid ELBOForms constant. - ValueError: if `variational_with_prior` is None and there are no - `StochasticTensor`s upstream of `log_likelihood`. - ValueError: if any variational does not have a prior passed or registered. - """ - if form is None: - form = ELBOForms.default - with ops.name_scope(name): - model = ops.convert_to_tensor(log_likelihood) - variational_with_prior = _find_variational_and_priors( - model, variational_with_prior) - return _elbo(form, log_likelihood, None, variational_with_prior, - keep_batch_dim) - - -def elbo_with_log_joint(log_joint, - variational=None, - keep_batch_dim=True, - form=None, - name="ELBO"): - """Evidence Lower BOund. `log p(x) >= ELBO`. - - This method is for models that have computed `p(x,Z)` instead of `p(x|Z)`. - See `elbo` for further details. - - Because only the joint is specified, analytic KL is not available. - - Args: - log_joint: `Tensor` log p(x, Z). - variational: list of `StochasticTensor` q(Z). If `None`, defaults to all - `StochasticTensor` objects upstream of `log_joint`. - keep_batch_dim: bool. Whether to keep the batch dimension when summing - entropy term. When the sample is per data point, this should be True; - otherwise (e.g. in a Bayesian NN), this should be False. - form: ELBOForms constant. Controls how the ELBO is computed. Defaults to - ELBOForms.default. - name: name to prefix ops with. - - Returns: - `Tensor` ELBO of the same type and shape as `log_joint`. - - Raises: - TypeError: if variationals in `variational` are not `StochasticTensor`s. - TypeError: if form is not a valid ELBOForms constant. - ValueError: if `variational` is None and there are no `StochasticTensor`s - upstream of `log_joint`. - ValueError: if form is ELBOForms.analytic_kl. - """ - if form is None: - form = ELBOForms.default - if form == ELBOForms.analytic_kl: - raise ValueError("ELBOForms.analytic_kl is not available when using " - "elbo_with_log_joint. Use elbo or a different form.") - - with ops.name_scope(name): - model = ops.convert_to_tensor(log_joint) - - variational_with_prior = None - if variational is not None: - variational_with_prior = dict(zip(variational, [None] * len(variational))) - variational_with_prior = _find_variational_and_priors( - model, variational_with_prior, require_prior=False) - return _elbo(form, None, log_joint, variational_with_prior, keep_batch_dim) - - -def _elbo(form, log_likelihood, log_joint, variational_with_prior, - keep_batch_dim): - """Internal implementation of ELBO. Users should use `elbo`. - - Args: - form: ELBOForms constant. Controls how the ELBO is computed. - log_likelihood: `Tensor` log p(x|Z). - log_joint: `Tensor` log p(x, Z). - variational_with_prior: `dict`, varational - distributions to prior distributions. - keep_batch_dim: bool. Whether to keep the batch dimension when reducing - the entropy/KL. - - Returns: - ELBO `Tensor` with same shape and dtype as `log_likelihood`/`log_joint`. - """ - ELBOForms.check_form(form) - - # Order of preference - # 1. Analytic KL: log_likelihood - KL(q||p) - # 2. Analytic entropy: log_likelihood + log p(Z) + H[q], or log_joint + H[q] - # 3. Sample: log_likelihood - (log q(Z) - log p(Z)) = - # log_likelihood + log p(Z) - log q(Z), or log_joint - q(Z) - - def _reduce(val): - if keep_batch_dim: - return val - else: - return math_ops.reduce_sum(val) - - kl_terms = [] - entropy_terms = [] - prior_terms = [] - for q, z, p in [(qz.distribution, qz.value(), pz) - for qz, pz in variational_with_prior.items()]: - # Analytic KL - kl = None - if log_joint is None and form in {ELBOForms.default, ELBOForms.analytic_kl}: - try: - kl = kullback_leibler.kl_divergence(q, p) - logging.info("Using analytic KL between q:%s, p:%s", q, p) - except NotImplementedError as e: - if form == ELBOForms.analytic_kl: - raise e - if kl is not None: - kl_terms.append(-1. * _reduce(kl)) - continue - - # Analytic entropy - entropy = None - if form in {ELBOForms.default, ELBOForms.analytic_entropy}: - try: - entropy = q.entropy() - logging.info("Using analytic entropy for q:%s", q) - except NotImplementedError as e: - if form == ELBOForms.analytic_entropy: - raise e - if entropy is not None: - entropy_terms.append(_reduce(entropy)) - if log_likelihood is not None: - prior = p.log_prob(z) - prior_terms.append(_reduce(prior)) - continue - - # Sample - if form in {ELBOForms.default, ELBOForms.sample}: - entropy = -q.log_prob(z) - entropy_terms.append(_reduce(entropy)) - if log_likelihood is not None: - prior = p.log_prob(z) - prior_terms.append(_reduce(prior)) - - first_term = log_joint if log_joint is not None else log_likelihood - return sum([first_term] + kl_terms + entropy_terms + prior_terms) - - -def _find_variational_and_priors(model, - variational_with_prior, - require_prior=True): - """Find upstream StochasticTensors and match with registered priors.""" - if variational_with_prior is None: - # pylint: disable=protected-access - upstreams = sg._upstream_stochastic_nodes([model]) - # pylint: enable=protected-access - upstreams = list(upstreams[model]) - if not upstreams: - raise ValueError("No upstream stochastic nodes found for tensor: %s", - model) - prior_map = dict(ops.get_collection(VI_PRIORS)) - variational_with_prior = {} - for q in upstreams: - if require_prior and (q not in prior_map or prior_map[q] is None): - raise ValueError("No prior specified for StochasticTensor: %s", q) - variational_with_prior[q] = prior_map.get(q) - - if not all( - [isinstance(q, st.StochasticTensor) for q in variational_with_prior]): - raise TypeError("variationals must be StochasticTensors") - if not all([ - p is None or isinstance(p, distribution.Distribution) - for p in variational_with_prior.values() - ]): - raise TypeError("priors must be Distribution objects") - - return variational_with_prior diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 66a04d42e93331de74b6f3d41f83f071115c1097..7072f56420ac9e576b20b62c0aa67498857403a7 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -359,8 +359,8 @@ tf_custom_op_library( ], deps = [ "//tensorflow/contrib/boosted_trees/lib:example_partitioner", - "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", "//tensorflow/contrib/boosted_trees/lib:models", + "//tensorflow/contrib/boosted_trees/lib:node-stats", "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", @@ -404,10 +404,12 @@ tf_kernel_library( name = "split_handler_ops_kernels", srcs = ["kernels/split_handler_ops.cc"], deps = [ - "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", + "//tensorflow/contrib/boosted_trees/lib:node-stats", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:protos_all_cc", + "//third_party/eigen3", ], alwayslink = 1, ) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index ef8dee91b6cc05c4c3dd5eb3c81de4fb65b473e3..6ebc7d7911df878ec91701db8b75feb9a27d18a2 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -33,6 +33,8 @@ from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader as saved_model_loader from tensorflow.python.saved_model import tag_constants +_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" + def make_custom_export_strategy(name, convert_fn, @@ -147,13 +149,12 @@ def convert_to_universal_format(dtec, sorted_feature_names, inequality_test.threshold.float_value = split.threshold elif node_type == "sparse_float_binary_split_default_left": split = gtflow_node.sparse_float_binary_split_default_left.split - node.default_direction = ( - generic_tree_model_pb2.BinaryNode.LEFT) - # TODO(nponomareva): adjust this id assignement when we allow multi- - # column sparse tensors. + node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT) feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -165,7 +166,9 @@ def convert_to_universal_format(dtec, sorted_feature_names, # column sparse tensors. feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -201,10 +204,14 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split_column = feature_names[split.feature_column] elif node_type == "sparse_float_binary_split_default_left": split = tree_node.sparse_float_binary_split_default_left.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "sparse_float_binary_split_default_right": split = tree_node.sparse_float_binary_split_default_right.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "categorical_id_binary_split": split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py index 4ed18b2d34c5af47826ab1c058f5d13797593bd4..492d9ca40c5cfa84e186020605429aacc02af6a6 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the conversion code from GTFlow format to Chauffeur.""" +"""Tests for the conversion code and for feature importances export. + +Tests that cover conversion from TFBT format to a tensorflow.contrib. +decision_tree generic_tree_model format and feature importances export. +""" from __future__ import absolute_import from __future__ import division @@ -95,10 +99,31 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } + nodes { + sparse_float_binary_split_default_right { + split { + feature_column: 1 + dimension_id:3 + threshold: -0.4 + left_id: 7 + right_id: 8 + } + } + node_metadata { + gain: 3600 + } + } + nodes { + leaf { + vector { + value: 0.36 + } + } + } nodes { leaf { vector { - value: 0.3 + value: 18 } } } @@ -108,17 +133,25 @@ class ConvertModelTest(test_util.TensorFlowTestCase): """ dtec = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge(dtec_str, dtec) - feature_columns = ["feature_b", "feature_a", "feature_d"] + feature_columns = [ + "feature_b", + "feature_a", + "feature_a_m", + "feature_d", + ] return dtec, feature_columns def testConvertModel(self): dtec, feature_columns = self._make_trees() + # Assume 2 sparse float columns, one with 1 dimension, the second one with + # 5 dimensions. # The feature columns in the order they were added. out = custom_export_strategy.convert_to_universal_format( - dtec, feature_columns, 1, 1, - 1) + dtec, feature_columns, 1, 2, 1) + # Features a and a_m are sparse float features, a_m is multidimensional. expected_tree = """ features { key: "feature_a" } + features { key: "feature_a_m" } features { key: "feature_b" } features { key: "feature_d" } model { @@ -169,7 +202,6 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } - nodes { node_id { value: 1 @@ -196,7 +228,7 @@ class ConvertModelTest(test_util.TensorFlowTestCase): inequality_left_child_test { feature_id { id { - value: "feature_a" + value: "feature_a_0" } } threshold { @@ -259,14 +291,51 @@ class ConvertModelTest(test_util.TensorFlowTestCase): node_id { value: 6 } + binary_node { + left_child_id { + value: 7 + } + right_child_id { + value: 8 + } + default_direction: RIGHT + inequality_left_child_test { + feature_id { + id { + value: "feature_a_m_3" + } + } + threshold { + float_value: -0.4 + } + } + } + } + nodes { + node_id { + value: 7 + } leaf { vector { value { - float_value: 0.03 + float_value: 0.036 } } } } + nodes { + node_id { + value: 8 + } + leaf { + vector { + value { + float_value: 1.8 + } + } + } + } + } } submodel_id { @@ -280,12 +349,15 @@ class ConvertModelTest(test_util.TensorFlowTestCase): def testFeatureImportance(self): dtec, feature_columns = self._make_trees() feature_importances = custom_export_strategy._get_feature_importances( - dtec, feature_columns, 1, 1, 1) - self.assertItemsEqual(["feature_b", "feature_a", "feature_d"], - feature_importances.keys()) + dtec, feature_columns, 1, 2, 1) + self.assertItemsEqual( + ["feature_b", "feature_a_0", "feature_a_m_3", "feature_d"], + feature_importances.keys()) self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4) - self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4) + self.assertAlmostEqual(50.0, feature_importances["feature_a_0"], places=4) self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4) + self.assertAlmostEqual( + 360.0, feature_importances["feature_a_m_3"], places=4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py index 2c0a3c4912b82aba88e2f8f1b97a227c894ee2ae..e9dbdb0fd784052eeb36ac1aa9342165ef2ac0a7 100644 --- a/tensorflow/contrib/boosted_trees/examples/boston.py +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -22,7 +22,7 @@ r"""Demonstrates a regression on Boston housing data. python tensorflow/contrib/boosted_trees/examples/boston.py \ --batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \ - --num_eval_steps=1 --num_trees=500 --l2=4 \ + --num_eval_steps=1 --num_trees=500 --l2=0.001 \ --vmodule=training_ops=1 When training is done, mean squared error on eval data is reported. @@ -37,8 +37,10 @@ from __future__ import division from __future__ import print_function import argparse +import os import sys import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch import custom_export_strategy from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column @@ -51,22 +53,18 @@ _BOSTON_NUM_FEATURES = 13 def _get_tfbt(output_dir, feature_cols): """Configures TF Boosted Trees estimator based on flags.""" learner_config = learner_pb2.LearnerConfig() - learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate learner_config.regularization.l1 = 0.0 - # Set the regularization per instance in such a way that - # regularization for the full training data is equal to l2 flag. - learner_config.regularization.l2 = FLAGS.l2 / FLAGS.batch_size + learner_config.regularization.l2 = FLAGS.l2 learner_config.constraints.max_tree_depth = FLAGS.depth - learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) # Create a TF Boosted trees regression estimator. estimator = GradientBoostedDecisionTreeRegressor( learner_config=learner_config, - # For the WHOLE_TREE strategy, set the examples_per_layer to be equal to - # batch size. + # This should be the number of examples. For large datasets it can be + # larger than the batch_size. examples_per_layer=FLAGS.batch_size, feature_columns=feature_cols, label_dimension=1, @@ -77,6 +75,14 @@ def _get_tfbt(output_dir, feature_cols): return estimator +def _convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float, + num_sparse_int, export_dir, unused_eval_result): + universal_format = custom_export_strategy.convert_to_universal_format( + dtec, sorted_feature_names, num_dense, num_sparse_float, num_sparse_int) + with tf.gfile.GFile(os.path.join(export_dir, "tree_proto"), "w") as f: + f.write(str(universal_format)) + + def _make_experiment_fn(output_dir): """Creates experiment for gradient boosted decision trees.""" (x_train, y_train), (x_test, @@ -88,21 +94,31 @@ def _make_experiment_fn(output_dir): batch_size=FLAGS.batch_size, num_epochs=None, shuffle=True) - eval_input_fn = tf.estimator.inputs.numpy_input_fn( x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) feature_columns = [ feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) ] - + feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( + feature_columns) + serving_input_fn = tf.contrib.learn.utils.build_parsing_serving_input_fn( + feature_spec) + # An export strategy that outputs the feature importance and also exports + # the internal tree representation in another format. + export_strategy = custom_export_strategy.make_custom_export_strategy( + "exports", + convert_fn=_convert_fn, + feature_columns=feature_columns, + export_input_fn=serving_input_fn) return tf.contrib.learn.Experiment( estimator=_get_tfbt(output_dir, feature_columns), train_input_fn=train_input_fn, eval_input_fn=eval_input_fn, train_steps=None, eval_steps=FLAGS.num_eval_steps, - eval_metrics=None) + eval_metrics=None, + export_strategies=[export_strategy]) def main(unused_argv): diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index 766982b4f2023310e6046619939f83bef63b0302..f8086b0c2bb93eae6af0336bbe33fc23f8fcde22 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -63,19 +63,26 @@ const char* kPredictionsTensorName = "predictions"; void CalculateTreesToInclude( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, const std::vector& trees_to_drop, const int32 num_trees, - const bool only_finalized, std::vector* trees_to_include) { + const bool only_finalized, const bool center_bias, + std::vector* trees_to_include) { trees_to_include->reserve(num_trees - trees_to_drop.size()); int32 index = 0; // This assumes that trees_to_drop is a sorted list of tree ids. for (int32 tree = 0; tree < num_trees; ++tree) { - if ((!trees_to_drop.empty() && index < trees_to_drop.size() && - trees_to_drop[index] == tree) || - (only_finalized && config.tree_metadata_size() > 0 && - !config.tree_metadata(tree).is_finalized())) { + // Skip the tree if tree is in the list of trees_to_drop. + if (!trees_to_drop.empty() && index < trees_to_drop.size() && + trees_to_drop[index] == tree) { ++index; continue; } + // Or skip if the tree is not finalized and only_finalized is set, + // with the exception of centering bias. + if (only_finalized && !(center_bias && tree == 0) && + config.tree_metadata_size() > 0 && + !config.tree_metadata(tree).is_finalized()) { + continue; + } trees_to_include->push_back(tree); } } @@ -250,7 +257,7 @@ class GradientTreesPredictionOp : public OpKernel { CalculateTreesToInclude( ensemble_resource->decision_tree_ensemble(), dropped_trees, ensemble_resource->decision_tree_ensemble().trees_size(), - only_finalized_trees_, &trees_to_include); + only_finalized_trees_, center_bias_, &trees_to_include); // Allocate output predictions matrix. Tensor* output_predictions_t = nullptr; diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index b08028eb635385357ba13b48d88157936978b6f1..8600c8c53caa5fd4274ba6730fc764d8315d680c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -50,6 +50,7 @@ const char* const kAreBucketsReadyName = "are_buckets_ready"; const char* const kNumSparseFeaturesName = "num_sparse_features"; const char* const kSparseBucketsName = "sparse_buckets"; const char* const kSparseValuesName = "sparse_values"; +const char* const kSparseIndicesName = "sparse_indices"; const char* const kSparseStreamsStateName = "sparse_streams_state"; const char* const kSparseSummariesName = "sparse_summaries"; const char* const kSparseConfigName = "sparse_config"; @@ -85,9 +86,23 @@ std::vector GetBuckets(const int32 feature, return buckets_vector; } -void QuantizeFeatures(const string& output_name, const OpInputList& values_list, - const OpInputList& buckets_list, - OpKernelContext* const context) { +int32 GetFeatureDimension(const int32 feature_index, const int64 instance, + const OpInputList* const indices_list) { + if (indices_list != nullptr) { + // Sparse multidimensional. + return (*indices_list)[feature_index].matrix()(instance, 1); + } + // No indices, assume one-dimensional tensor. + return 0; +} + +// Allows quantization for each of multiple dimensions of a sparse feature. +void QuantizeFeatures( + const string& output_name, const OpInputList& values_list, + const OpInputList& buckets_list, + const OpInputList* const + indices_list /** Optional, provide for sparse features **/, + OpKernelContext* const context) { if (values_list.size() == 0) { return; } @@ -100,10 +115,13 @@ void QuantizeFeatures(const string& output_name, const OpInputList& values_list, const int64 num_values = values_tensor.dim_size(0); Tensor* output_t = nullptr; + // Output will have bucket id and dimension of the features for that bucket. OP_REQUIRES_OK( - context, output_list.allocate(feature_index, TensorShape({num_values}), - &output_t)); - TTypes::Vec output = output_t->vec(); + context, output_list.allocate(feature_index, + TensorShape({num_values, 2}), &output_t)); + + auto output = output_t->matrix(); + const std::vector& buckets_vector = GetBuckets(feature_index, buckets_list); auto flat_values = values_tensor.flat(); @@ -116,7 +134,11 @@ void QuantizeFeatures(const string& output_name, const OpInputList& values_list, } const int32 bucket = static_cast(bucket_iter - buckets_vector.begin()); - output(instance) = bucket; + // Bucket id. + output(instance, 0) = bucket; + // Dimension. + output(instance, 1) = + GetFeatureDimension(feature_index, instance, indices_list); } } } @@ -851,6 +873,11 @@ class QuantilesOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list(kSparseValuesName, &sparse_float_feature_values_list)); + + OpInputList sparse_float_indices_list; + OP_REQUIRES_OK(context, context->input_list(kSparseIndicesName, + &sparse_float_indices_list)); + OpInputList sparse_buckets_list; OP_REQUIRES_OK( context, context->input_list(kSparseBucketsName, &sparse_buckets_list)); @@ -865,10 +892,10 @@ class QuantilesOp : public OpKernel { // Quantize the feature values QuantizeFeatures(kDenseOutputTensorName, dense_float_features_list, - dense_buckets_list, context); + dense_buckets_list, nullptr, context); QuantizeFeatures(kSparseOutputTensorName, sparse_float_feature_values_list, - sparse_buckets_list, context); + sparse_buckets_list, &sparse_float_indices_list, context); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 29635bb3c404e54f0561d9b9189270022f063cbe..18b4abd654ea3541d646a43ac901aca1a678446f 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -16,7 +16,7 @@ #include #include -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/core/framework/device_base.h" @@ -39,6 +39,10 @@ using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; using boosted_trees::learner::LearnerConfig_MultiClassStrategy; +namespace { +const int32 DUMMY_FEATURE_DIMENSION = -1; +} // namespace + class BaseBuildSplitOp : public OpKernel { public: explicit BaseBuildSplitOp(OpKernelConstruction* const context) @@ -128,7 +132,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* bucket_ids_t; OP_REQUIRES_OK(context, context->input("bucket_ids", &bucket_ids_t)); - const auto& bucket_ids = bucket_ids_t->vec(); + const auto& bucket_ids = bucket_ids_t->matrix(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -219,7 +223,7 @@ class BuildDenseInequalitySplitsOp : public BaseBuildSplitOp { split_info.mutable_split_node()->mutable_dense_float_binary_split(); dense_split->set_feature_column(feature_column_group_id_); dense_split->set_threshold( - bucket_boundaries(bucket_ids(best_bucket_idx))); + bucket_boundaries(bucket_ids(best_bucket_idx, 0))); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); @@ -262,7 +266,9 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { const Tensor* bucket_ids_t; OP_REQUIRES_OK(context, context->input("bucket_ids", &bucket_ids_t)); - const auto& bucket_ids = bucket_ids_t->vec(); + const auto& bucket_ids_and_dimensions = bucket_ids_t->matrix(); + + const int32 tensor_elements = partition_ids.size(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -273,24 +279,59 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { int class_id; ReadClassId(context, &class_id); - // Find the number of unique partitions before we allocate the output. - std::vector partition_boundaries; + // For each partition (tree node), store starting index for each dimension. + PartitionAndDimensionBoundaries partition_boundaries; + // Stores indices in partition_boundaries for those partitions that are + // not empty (have at least one dimension and a bucket apart from catch-all + // bucket of -1 bucket id and dimension 0. std::vector non_empty_partitions; - for (int i = 0; i < partition_ids.size() - 1; ++i) { + bool non_empty_partition = false; + + for (int i = 0; i < partition_ids.size(); ++i) { // Make sure the input is sorted by partition_ids; - CHECK_LE(partition_ids(i), partition_ids(i + 1)); - if (i == 0 || partition_ids(i) != partition_ids(i - 1)) { - partition_boundaries.push_back(i); - // Some partitions might only have bias feature. We don't want to split - // those so check that the partition has at least 2 buckets. - if (partition_ids(i) == partition_ids(i + 1)) { - non_empty_partitions.push_back(partition_boundaries.size() - 1); + if (i > 0) { + CHECK_LE(partition_ids(i - 1), partition_ids(i)) + << "Partition ids should be sorted. Not sorted for " << i; + } + const int32 dimension = bucket_ids_and_dimensions(i, 1); + + if (i == 0 || (partition_ids(i) != partition_ids(i - 1))) { + if (i != 0) { + // Not the first entry, so partition has changed. + if (non_empty_partition) { + // Saves the id of a previous partition in a list of non empty + // partitions, since it was non empty (had more than just a bias + // bucket -1. + non_empty_partitions.push_back(partition_boundaries.size() - 1); + } + // Add dummy dimension to signify the end for the previous dimension. + partition_boundaries.back().emplace_back(DUMMY_FEATURE_DIMENSION, i); } + // Allocate for a new partition. + partition_boundaries.emplace_back(); + // Save info about the first dimension for a new partition. + partition_boundaries.back().emplace_back(dimension, i); + + // Each partition has dummy -1 bucket with all gradients and then info + // for all other dimensions -> if we have >1 elements for a partition, + // then it is not empty. + non_empty_partition = (i < partition_ids.size() - 1) && + (partition_ids(i) == partition_ids(i + 1)); + } else if (bucket_ids_and_dimensions(i, 1) != + bucket_ids_and_dimensions(i - 1, 1)) { + // Dimension changed. + partition_boundaries.back().emplace_back(dimension, i); } } - if (partition_ids.size() > 0) { - partition_boundaries.push_back(partition_ids.size()); + if (tensor_elements > 0) { + if (non_empty_partition) { + non_empty_partitions.push_back(partition_boundaries.size() - 1); + } + // Add dummy dimension to signify the end for the previous dimension. + partition_boundaries.back().emplace_back(DUMMY_FEATURE_DIMENSION, + partition_ids.size()); } + int num_elements = non_empty_partitions.size(); Tensor* output_partition_ids_t = nullptr; OP_REQUIRES_OK(context, @@ -314,73 +355,128 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { &output_splits_t)); tensorflow::TTypes::Vec output_splits = output_splits_t->vec(); + // For each tree node that needs to be split. for (int root_idx = 0; root_idx < num_elements; ++root_idx) { + const auto& dimension_boundaries = + partition_boundaries[non_empty_partitions[root_idx]]; + float best_gain = std::numeric_limits::lowest(); - int start_index = partition_boundaries[non_empty_partitions[root_idx]]; - int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; - // First bucket ID in each partition should be the bias feature. - OP_REQUIRES(context, bucket_ids(start_index) == bias_feature_id_, - errors::InvalidArgument("Bias feature ID missing.")); + int32 best_dimension_idx = 0; + bool default_right = false; + int32 best_element_idx = 0; + + NodeStats best_right_node_stats(0); + NodeStats best_left_node_stats(0); + + // For each partition, the first bucket is dummy catch all. + int32 bias_start_index = dimension_boundaries[0].start_index; + + OP_REQUIRES( + context, + bucket_ids_and_dimensions(bias_start_index, 0) == bias_feature_id_, + errors::InvalidArgument("Bias feature ID missing.")); + + // Dimension for bias feature is always 0 + OP_REQUIRES( + context, bucket_ids_and_dimensions(bias_start_index, 1) == 0, + errors::InvalidArgument("Bias feature ID must be with dimension 0.")); + // For each root, we do two passes over the quantized feature buckets // accumulating gradients on one side and using the root aggregate // gradients to get the gradients for the other side. // Split gains are evaluated for each pass at every threshold and the best // split is picked. - GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); + GradientStats root_gradient_stats(*gradients_t, *hessians_t, + bias_start_index); root_gradient_stats *= normalizer_ratio; NodeStats root_stats = ComputeNodeStats(root_gradient_stats); - GradientStats present_gradient_stats; - for (int64 bucket_idx = start_index + 1; bucket_idx < end_index; - ++bucket_idx) { - present_gradient_stats += - GradientStats(*gradients_t, *hessians_t, bucket_idx); - } - present_gradient_stats *= normalizer_ratio; - int32 best_bucket_idx = 0; - NodeStats best_right_node_stats(0); - NodeStats best_left_node_stats(0); - GradientStats left_gradient_stats; - bool default_right = false; - for (int64 bucket_idx = start_index + 1; bucket_idx < end_index; - ++bucket_idx) { - GradientStats g(*gradients_t, *hessians_t, bucket_idx); - g *= normalizer_ratio; - left_gradient_stats += g; - // We have the sum of all present gradients. Use that to compute the - // backward pass gradients. - GradientStats right_gradient_stats = - present_gradient_stats - left_gradient_stats; - { - NodeStats left_stats_default_left = - ComputeNodeStats(root_gradient_stats - right_gradient_stats); - NodeStats right_stats_default_left = - ComputeNodeStats(right_gradient_stats); - if (left_stats_default_left.gain + right_stats_default_left.gain > - best_gain) { - best_gain = - left_stats_default_left.gain + right_stats_default_left.gain; - best_left_node_stats = left_stats_default_left; - best_right_node_stats = right_stats_default_left; - best_bucket_idx = bucket_idx; - default_right = false; - } + + // Iterate through dimensions. + for (int j = 0; j < dimension_boundaries.size() - 1; ++j) { + const DimensionBoundary& dimension_and_start = dimension_boundaries[j]; + const int32 dimension_id = dimension_and_start.dimension_id; + + int start_index = dimension_and_start.start_index; + // Even for the last dimension, we always have additional dummy + // dimension that we can use to find the end index. + const int end_index = + partition_boundaries[non_empty_partitions[root_idx]][j + 1] + .start_index; + CHECK(bucket_ids_and_dimensions(start_index, 1) == + bucket_ids_and_dimensions(end_index - 1, 1)) + << "For bucket " << bucket_ids_and_dimensions(start_index, 0) + << " the dimension was " + << bucket_ids_and_dimensions(start_index, 1) << " and for " + << bucket_ids_and_dimensions(end_index - 1, 0) << " " + << bucket_ids_and_dimensions(end_index - 1, 1); + if (bucket_ids_and_dimensions(start_index, 0) == bias_feature_id_) { + // 0-dimension case which has a first bucket for catch all feature. + CHECK(bucket_ids_and_dimensions(start_index, 1) == 0) + << "Dimension of bias feature should be 0"; + ++start_index; } - { - NodeStats left_stats_default_right = - ComputeNodeStats(left_gradient_stats); - NodeStats right_stats_default_right = - ComputeNodeStats(root_gradient_stats - left_gradient_stats); - if (left_stats_default_right.gain + right_stats_default_right.gain > - best_gain) { - best_gain = - left_stats_default_right.gain + right_stats_default_right.gain; - best_left_node_stats = left_stats_default_right; - best_right_node_stats = right_stats_default_right; - best_bucket_idx = bucket_idx; - default_right = true; + + GradientStats present_gradient_stats; + for (int64 bucket_idx = start_index; bucket_idx < end_index; + ++bucket_idx) { + present_gradient_stats += + GradientStats(*gradients_t, *hessians_t, bucket_idx); + } + present_gradient_stats *= normalizer_ratio; + + GradientStats left_gradient_stats; + for (int64 element_idx = start_index; element_idx < end_index; + ++element_idx) { + // Check that bucket ids are sorted. + if (element_idx != start_index) { + CHECK(bucket_ids_and_dimensions(element_idx - 1, 0) < + bucket_ids_and_dimensions(element_idx, 0)) + << "Bucket ids must be sorted." + << ", problem on " << element_idx << " and dimension is " << j; + } + + GradientStats g(*gradients_t, *hessians_t, element_idx); + g *= normalizer_ratio; + left_gradient_stats += g; + // We have the sum of all present gradients. Use that to compute the + // backward pass gradients. + GradientStats right_gradient_stats = + present_gradient_stats - left_gradient_stats; + { + NodeStats left_stats_default_left = + ComputeNodeStats(root_gradient_stats - right_gradient_stats); + NodeStats right_stats_default_left = + ComputeNodeStats(right_gradient_stats); + if (left_stats_default_left.gain + right_stats_default_left.gain > + best_gain) { + best_gain = + left_stats_default_left.gain + right_stats_default_left.gain; + best_left_node_stats = left_stats_default_left; + best_right_node_stats = right_stats_default_left; + best_element_idx = element_idx; + default_right = false; + best_dimension_idx = dimension_id; + } + } + { + NodeStats left_stats_default_right = + ComputeNodeStats(left_gradient_stats); + NodeStats right_stats_default_right = + ComputeNodeStats(root_gradient_stats - left_gradient_stats); + if (left_stats_default_right.gain + right_stats_default_right.gain > + best_gain) { + best_gain = left_stats_default_right.gain + + right_stats_default_right.gain; + best_left_node_stats = left_stats_default_right; + best_right_node_stats = right_stats_default_right; + best_element_idx = element_idx; + default_right = true; + best_dimension_idx = dimension_id; + } } } } + SplitInfo split_info; boosted_trees::trees::DenseFloatBinarySplit* dense_split = nullptr; if (default_right) { @@ -393,8 +489,13 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { ->mutable_split(); } dense_split->set_feature_column(feature_column_group_id_); - dense_split->set_threshold( - bucket_boundaries(bucket_ids(best_bucket_idx))); + // Set the feature index for the best feature column. + const int64 best_dimension_id = + bucket_ids_and_dimensions(best_element_idx, 1); + const int32 best_bucket_id = + bucket_ids_and_dimensions(best_element_idx, 0); + dense_split->set_dimension_id(best_dimension_id); + dense_split->set_threshold(bucket_boundaries(best_bucket_id)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); @@ -403,11 +504,23 @@ class BuildSparseInequalitySplitsOp : public BaseBuildSplitOp { split_info.SerializeToString(&output_splits(root_idx)); gains(root_idx) = best_gain - root_stats.gain - tree_complexity_regularization_; - output_partition_ids(root_idx) = partition_ids(start_index); + output_partition_ids(root_idx) = partition_ids(bias_start_index); } } private: + struct DimensionBoundary { + DimensionBoundary(const int32 dimension_id, const int32 start_index) + : dimension_id(dimension_id), start_index(start_index) {} + + int32 dimension_id; + int32 start_index; + }; + + // For each partition, store start indices of feature column dimensions. + typedef std::vector> + PartitionAndDimensionBoundaries; + int64 bias_feature_id_; }; REGISTER_KERNEL_BUILDER(Name("BuildSparseInequalitySplits").Device(DEVICE_CPU), @@ -434,7 +547,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { const Tensor* feature_ids_t; OP_REQUIRES_OK(context, context->input("feature_ids", &feature_ids_t)); - const auto& feature_ids = feature_ids_t->vec(); + const auto& feature_ids = feature_ids_t->matrix(); const Tensor* gradients_t; OP_REQUIRES_OK(context, context->input("gradients", &gradients_t)); @@ -491,7 +604,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { int start_index = partition_boundaries[non_empty_partitions[root_idx]]; int end_index = partition_boundaries[non_empty_partitions[root_idx] + 1]; // First feature ID in each partition should be the bias feature. - OP_REQUIRES(context, feature_ids(start_index) == bias_feature_id_, + OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id_, errors::InvalidArgument("Bias feature ID missing.")); GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index); root_gradient_stats *= normalizer_ratio; @@ -519,7 +632,7 @@ class BuildCategoricalEqualitySplitsOp : public BaseBuildSplitOp { auto* equality_split = split_info.mutable_split_node() ->mutable_categorical_id_binary_split(); equality_split->set_feature_column(feature_column_group_id_); - equality_split->set_feature_id(feature_ids(best_feature_idx)); + equality_split->set_feature_id(feature_ids(best_feature_idx, 0)); auto* left_child = split_info.mutable_left_child(); auto* right_child = split_info.mutable_right_child(); FillLeaf(class_id, best_left_node_stats, left_child); diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index cff75e71d93cb703d87bb09a4b32439e01d70f76..a9a229c8ae0c26bba5f0a684dad7e546298577bb 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -39,13 +39,14 @@ const char* const kStampTokenName = "stamp_token"; const char* const kNextStampTokenName = "next_stamp_token"; struct PartitionKey { - PartitionKey() : partition_id(-1), feature_id(-1) {} + PartitionKey() : partition_id(-1), feature_id(-1), dimension(-1) {} - PartitionKey(int32 p, int64 f) : partition_id(p), feature_id(f) {} + PartitionKey(int32 p, int64 f, int32 d) + : partition_id(p), feature_id(f), dimension(d) {} bool operator==(const PartitionKey& other) const { - return (feature_id == other.feature_id) && - (partition_id == other.partition_id); + return (partition_id == other.partition_id) && + (dimension == other.dimension) && (feature_id == other.feature_id); } // Compare for PartitionKey. @@ -54,7 +55,11 @@ struct PartitionKey { if (a.partition_id < b.partition_id) { return true; } - if ((a.partition_id == b.partition_id) && (a.feature_id < b.feature_id)) { + if ((a.partition_id == b.partition_id) && (a.dimension < b.dimension)) { + return true; + } + if ((a.partition_id == b.partition_id) && (a.dimension == b.dimension) && + (a.feature_id < b.feature_id)) { return true; } return false; @@ -64,8 +69,11 @@ struct PartitionKey { // Tree partition defined by traversing the tree to the leaf. int32 partition_id; - // Feature Id within the feature column. + // Feature column id. int64 feature_id; + + // Dimension within feature column. + int32 dimension; }; template @@ -132,12 +140,12 @@ void SerializeScalarAccumulatorToOutput( &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); + // Feature ids tensor has ids of feature columns and their dimensions. Tensor* feature_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_feature_ids", TensorShape({num_slots}), - &feature_ids_t)); - auto feature_ids = feature_ids_t->vec(); + OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", + TensorShape({num_slots, 2}), + &feature_ids_t)); + auto feature_ids = feature_ids_t->matrix(); Tensor* gradients_t = nullptr; OP_REQUIRES_OK( @@ -155,7 +163,9 @@ void SerializeScalarAccumulatorToOutput( int i = 0; for (const auto& iter : accumulator_resource.values()) { partition_ids(i) = iter.first.partition_id; - feature_ids(i) = iter.first.feature_id; + feature_ids(i, 0) = iter.first.feature_id; + feature_ids(i, 1) = iter.first.dimension; + gradients(i) = iter.second.first; hessians(i) = iter.second.second; ++i; @@ -174,11 +184,10 @@ void SerializeTensorAccumulatorToOutput( auto partition_ids = partition_ids_t->vec(); Tensor* feature_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_feature_ids", TensorShape({num_slots}), - &feature_ids_t)); - auto feature_ids = feature_ids_t->vec(); + OP_REQUIRES_OK(context, context->allocate_output("output_feature_ids", + TensorShape({num_slots, 2}), + &feature_ids_t)); + auto feature_ids = feature_ids_t->matrix(); TensorShape gradient_shape = accumulator_resource.gradient_shape(); int64 num_gradient_elements = gradient_shape.num_elements(); @@ -201,7 +210,9 @@ void SerializeTensorAccumulatorToOutput( int i = 0; for (const auto& iter : accumulator_resource.values()) { partition_ids(i) = iter.first.partition_id; - feature_ids(i) = iter.first.feature_id; + feature_ids(i, 0) = iter.first.feature_id; + feature_ids(i, 1) = iter.first.dimension; + for (int j = 0; j < num_gradient_elements; ++j) { gradients(i, j) = iter.second.first[j]; } @@ -220,14 +231,16 @@ void AddToScalarAccumulator( 1); const TensorShape& partition_ids_shape = partition_ids_t.shape(); const auto& partition_ids = partition_ids_t.vec(); - const auto& feature_ids = feature_ids_t.vec(); + const auto& feature_ids_and_dimensions = feature_ids_t.matrix(); const auto& gradients = gradients_t.vec(); const auto& hessians = hessians_t.vec(); int64 num_updates = partition_ids_shape.dim_size(0); auto stats_map = accumulator_resource->mutable_values(); for (int64 i = 0; i < num_updates; ++i) { - const auto key = PartitionKey(partition_ids(i), feature_ids(i)); + const auto key = + PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), + feature_ids_and_dimensions(i, 1)); auto itr = stats_map->find(key); if (itr != stats_map->end()) { itr->second.first += gradients(i); @@ -263,7 +276,7 @@ void AddToTensorAccumulator( const TensorShape& partition_ids_shape = partition_ids_t.shape(); const auto& partition_ids = partition_ids_t.vec(); - const auto& feature_ids = feature_ids_t.vec(); + const auto& feature_ids_and_dimensions = feature_ids_t.matrix(); TensorShape gradients_shape = gradients_t.shape(); const auto& gradients = gradients_t.flat_outer_dims(); TensorShape hessians_shape = hessians_t.shape(); @@ -288,7 +301,9 @@ void AddToTensorAccumulator( int64 num_updates = partition_ids_shape.dim_size(0); auto stats_map = accumulator_resource->mutable_values(); for (int64 i = 0; i < num_updates; ++i) { - const auto key = PartitionKey(partition_ids(i), feature_ids(i)); + const auto key = + PartitionKey(partition_ids(i), feature_ids_and_dimensions(i, 0), + feature_ids_and_dimensions(i, 1)); auto itr = stats_map->find(key); if (itr == stats_map->end()) { std::vector new_gradients(gradients_shape.num_elements()); diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index 4c56718f1bbc0b42c1f5454ddfafe6ccd8c35c2c..c77d90e243c304ec8e9a10a0b63401f9bd825c3e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -208,27 +208,19 @@ class CenterTreeEnsembleBiasOp : public OpKernel { int64 next_stamp_token = next_stamp_token_t->scalar()(); CHECK(stamp_token != next_stamp_token); + // Update the ensemble stamp. + ensemble_resource->set_stamp(next_stamp_token); + // Get the delta updates. const Tensor* delta_updates_t; OP_REQUIRES_OK(context, context->input("delta_updates", &delta_updates_t)); - OP_REQUIRES( - context, - delta_updates_t->dim_size(0) + 1 == learner_config_.num_classes(), - errors::InvalidArgument( - "Delta updates size must be consistent with label dimensions.")); auto delta_updates = delta_updates_t->vec(); - - // Update the ensemble stamp. - ensemble_resource->set_stamp(next_stamp_token); + const int64 logits_dimension = delta_updates_t->dim_size(0); // Get the bias. - boosted_trees::trees::Leaf* const bias = RetrieveBias(ensemble_resource); + boosted_trees::trees::Leaf* const bias = + RetrieveBias(ensemble_resource, logits_dimension); CHECK(bias->has_vector()); - OP_REQUIRES( - context, - bias->vector().value_size() + 1 == learner_config_.num_classes(), - errors::InvalidArgument( - "Bias vector size must be consistent with label dimensions.")); // Update the bias. float total_delta = 0; @@ -245,6 +237,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel { VLOG(1) << "Continuing to center bias, delta=" << total_delta; } else { VLOG(1) << "Done centering bias, delta=" << total_delta; + ensemble_resource->LastTreeMetadata()->set_is_finalized(true); } Tensor* continue_centering_t = nullptr; OP_REQUIRES_OK( @@ -256,7 +249,8 @@ class CenterTreeEnsembleBiasOp : public OpKernel { private: // Helper method to retrieve the bias from the tree ensemble. boosted_trees::trees::Leaf* RetrieveBias( - boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) { + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource, + int64 logits_dimension) { const int32 num_trees = ensemble_resource->num_trees(); if (num_trees <= 0) { // Add a new bias leaf. @@ -264,10 +258,9 @@ class CenterTreeEnsembleBiasOp : public OpKernel { boosted_trees::trees::DecisionTreeConfig* const tree_config = ensemble_resource->AddNewTree(1.0); auto* const leaf = tree_config->add_nodes()->mutable_leaf(); - for (size_t idx = 0; idx + 1 < learner_config_.num_classes(); ++idx) { + for (size_t idx = 0; idx < logits_dimension; ++idx) { leaf->mutable_vector()->add_value(0.0); } - ensemble_resource->LastTreeMetadata()->set_is_finalized(true); return leaf; } else if (num_trees == 1) { // Confirms that the only tree is a bias and returns its leaf. diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index 107ff0d295bee530c1711a97849fbd3c6cdb2f00..131bd48562a55a08981ac73277e93024db0d85d3 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -406,51 +406,9 @@ tf_cc_test( ) # Learner/stochastic - -cc_library( - name = "feature-column-handlers", - srcs = [ - "learner/stochastic/handlers/bias-feature-column-handler.cc", - "learner/stochastic/handlers/categorical-feature-column-handler.cc", - "learner/stochastic/handlers/dense-quantized-feature-column-handler.cc", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc", - ], - hdrs = [ - "learner/stochastic/handlers/bias-feature-column-handler.h", - "learner/stochastic/handlers/categorical-feature-column-handler.h", - "learner/stochastic/handlers/dense-quantized-feature-column-handler.h", - "learner/stochastic/handlers/feature-column-handler.h", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler.h", - ], - deps = [ - ":feature-split-candidate", - ":feature-stats-accumulator", - "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:protos_all_cc", - ], -) - -tf_cc_test( - name = "feature-column-handlers_test", - size = "small", - srcs = [ - "learner/stochastic/handlers/bias-feature-column-handler_test.cc", - "learner/stochastic/handlers/categorical-feature-column-handler_test.cc", - "learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc", - "learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc", - ], - deps = [ - ":feature-column-handlers", - "//tensorflow/core:tensor_testutil", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "gradient-stats", - hdrs = ["learner/stochastic/stats/gradient-stats.h"], + hdrs = ["learner/common/stats/gradient-stats.h"], deps = [ "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", @@ -459,7 +417,7 @@ cc_library( cc_library( name = "node-stats", - hdrs = ["learner/stochastic/stats/node-stats.h"], + hdrs = ["learner/common/stats/node-stats.h"], deps = [ ":gradient-stats", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", @@ -471,7 +429,7 @@ cc_library( cc_library( name = "split-stats", - hdrs = ["learner/stochastic/stats/split-stats.h"], + hdrs = ["learner/common/stats/split-stats.h"], deps = [ ":node-stats", ], @@ -479,7 +437,7 @@ cc_library( cc_library( name = "feature-split-candidate", - hdrs = ["learner/stochastic/stats/feature-split-candidate.h"], + hdrs = ["learner/common/stats/feature-split-candidate.h"], deps = [ ":split-stats", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", @@ -489,7 +447,7 @@ cc_library( tf_cc_test( name = "node-stats_test", size = "small", - srcs = ["learner/stochastic/stats/node-stats_test.cc"], + srcs = ["learner/common/stats/node-stats_test.cc"], deps = [ ":node-stats", "//tensorflow/core:tensor_testutil", diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py index 83dad7e4b3301327bcbae5203e9d9330c9e0084d..9f78ab20242800fd8af7ad049d5970fbe26ec0ea 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py @@ -110,8 +110,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): def not_active_inputs(): return (constant_op.constant([], dtype=dtypes.int32), - constant_op.constant([], dtype=dtypes.int64), empty_gradients, - empty_hessians) + constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) def active_inputs(): """The normal flow when the handler is active.""" @@ -154,7 +154,12 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler): [per_partition_hessians, filtered_hessians], 0) feature_ids = array_ops.concat( [bias_feature_ids, self._sparse_int_column.values], 0) - return partition_ids, feature_ids, filtered_gradients, filtered_hessians + # Dimension is always zero for sparse int features. + dimension_ids = array_ops.zeros_like(feature_ids, dtype=dtypes.int64) + feature_ids_and_dimensions = array_ops.stack( + [feature_ids, dimension_ids], axis=1) + return (partition_ids, feature_ids_and_dimensions, filtered_gradients, + filtered_hessians) partition_ids, feature_ids, gradients_out, hessians_out = ( control_flow_ops.cond(is_active[0], active_inputs, not_active_inputs)) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 8c0a3f0d91e0fbd6b6ca02352c8b80b8485d029d..72e20aaa127cda592bd314786cddb925cc87a075 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -257,6 +257,7 @@ class DenseSplitHandler(InequalitySplitHandler): # Put quantile and stats accumulator flushing in the dependency path. are_splits_ready = control_flow_ops.with_dependencies( [flush_quantiles, partition_ids], are_splits_ready) + partition_ids, gains, split_infos = ( split_handler_ops.build_dense_inequality_splits( num_minibatches=num_minibatches, @@ -433,14 +434,15 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, def ready_inputs_fn(): """Branch to execute when quantiles are ready.""" quantized_feature = quantile_ops.quantiles([float_column], [], - [quantile_buckets], []) + [quantile_buckets], [], []) quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.reshape(quantized_feature, [-1]) + quantized_feature = array_ops.squeeze(quantized_feature) return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): - return (constant_op.constant([], dtype=dtypes.int32), constant_op.constant( - [], dtype=dtypes.int64), empty_gradients, empty_hessians) + return (constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([[]], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) example_partition_ids, feature_ids, gradients, hessians = ( control_flow_ops.cond( @@ -461,10 +463,13 @@ def sparse_make_stats_update( def quantiles_ready(): """The subgraph for when the quantiles are ready.""" - quantized_feature = quantile_ops.quantiles([sparse_column_values], [], - [quantile_buckets], []) - quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.reshape(quantized_feature, [-1]) + quantized_feature = quantile_ops.quantiles([], [sparse_column_values], [], + [quantile_buckets], + [sparse_column_indices]) + + quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64) + quantized_feature = array_ops.squeeze(quantized_feature) + example_indices, _ = array_ops.split( sparse_column_indices, num_or_size_splits=2, axis=1) example_indices = array_ops.squeeze(example_indices, [1]) @@ -486,19 +491,25 @@ def sparse_make_stats_update( bias_feature_ids = array_ops.fill( array_ops.shape(unique_partitions), _BIAS_FEATURE_ID) bias_feature_ids = math_ops.cast(bias_feature_ids, dtypes.int64) + zeros = array_ops.zeros_like(bias_feature_ids) + bias_feature_ids = array_ops.stack([bias_feature_ids, zeros], axis=1) + partition_ids = array_ops.concat( [unique_partitions, filtered_partition_ids], 0) filtered_gradients = array_ops.concat( [per_partition_gradients, filtered_gradients], 0) filtered_hessians = array_ops.concat( [per_partition_hessians, filtered_hessians], 0) + bucket_ids = array_ops.concat([bias_feature_ids, quantized_feature], 0) + return partition_ids, bucket_ids, filtered_gradients, filtered_hessians def quantiles_not_ready(): """The subgraph for when the quantiles are not ready.""" - return (constant_op.constant([], dtype=dtypes.int32), constant_op.constant( - [], dtype=dtypes.int64), empty_gradients, empty_hessians) + return (constant_op.constant([], dtype=dtypes.int32), + constant_op.constant([], dtype=dtypes.int64, shape=[1, 2]), + empty_gradients, empty_hessians) empty_float = constant_op.constant([], dtype=dtypes.float32) handler_not_active = (constant_op.constant( diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h similarity index 90% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h index fe22691178213094b9affcdee06af98011f85bd2..339c2e0fded10e6a7b140da62e152e2868ffd164 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h @@ -13,10 +13,10 @@ // limitations under the License. // // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" namespace tensorflow { @@ -58,4 +58,4 @@ struct FeatureSplitCandidate { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h similarity index 98% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h index dad64bf165a41bc4f32eea6b37e7afb569887a06..34e3ddb777242553d62035a51f1aec33d0f9ba54 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ #include @@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_GRADIENT_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h similarity index 98% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h index 4e5f53874df2207ffa6664a33675f84ef055394b..642a183aec5c7e591579fa5ee91d45729bfb624d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Eigenvalues" -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/gradient-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h" #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/core/framework/shape_inference.h" @@ -298,4 +298,4 @@ struct NodeStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_NODE_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc similarity index 99% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc index ecb7a04efb96248210d9af770c8377b7f6906598..f867e77d3ef0609774628b2a9c36ca52bcf2a957 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h similarity index 94% rename from tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h rename to tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h index f700cbced833543227de39f54c9ecbb03a7ce7c9..054ccd9a8cd0be0c48b14cca013f15677deba900 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/split-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ #include -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/node-stats.h" +#include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h" namespace tensorflow { namespace boosted_trees { @@ -81,4 +81,4 @@ struct SplitStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_STATS_SPLIT_STATS_H_ +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc deleted file mode 100644 index b880cf2c47989b1434f17802befb7dd7c248b36f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -void BiasFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all examples and aggregate gradient stats for each sub-root. - for (int64 example_idx = 0; example_idx < batch_size_; ++example_idx) { - auto partition_id = example_partition_ids[example_idx]; - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, kBiasFeatureId, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void BiasFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - split_candidates->clear(); - split_candidates->reserve(roots.size()); - boosted_trees::trees::TreeNode tree_node; - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - const NodeStats& root_node_stats = root_stats[root_idx]; - tree_node.Clear(); - root_node_stats.FillLeaf(class_id_, tree_node.mutable_leaf()); - split_candidates->emplace_back(slot_id_, tree_node, - SplitStats(learner_config, root_node_stats)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h deleted file mode 100644 index 5c0f99185a63db33a391a98fa16f37bef99507c9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a bias feature column in the single class case. -// This handler is useful even if we don't introduce a bias feature because -// it allows us to aggregate stats per partition which in turn allows us -// to compute node stats for each root to split. -class BiasFeatureColumnHandler : public FeatureColumnHandler { - public: - BiasFeatureColumnHandler(const uint32 class_id, const uint32 slot_id, - const int64 batch_size) - : FeatureColumnHandler(class_id, slot_id, batch_size) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - static constexpr auto kBiasFeatureId = 0; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_H_ // NOLINT diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc deleted file mode 100644 index f4c7df7fabda1a38d7e6cca4c5c8bc81cb7551b1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler_test.cc +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 7; -const auto kSlotId = 0; -const auto kBatchSize = 4; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class BiasFeatureColumnHandlerTest : public ::testing::Test { - protected: - BiasFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 1, 3}) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - - // Create handler. - handler_.reset(new BiasFeatureColumnHandler(kClassId, kSlotId, kBatchSize)); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - std::unique_ptr handler_; -}; - -TEST_F(BiasFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition. - // Partition 0. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(-0.3f, 0.19f), - accumulator.GetStats(kSlotId, kClassId, 0, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 1. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 1, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 2. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 2, - BiasFeatureColumnHandler::kBiasFeatureId)); - // Partition 3. - EXPECT_GRADIENT_STATS_EQ( - GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 3, - BiasFeatureColumnHandler::kBiasFeatureId)); -} - -TEST_F(BiasFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 3. - // Root 0 has zero gain and root 3 has the same gain as the leaf. - const std::vector roots = {0, 3}; - const std::vector& root_stats = { - NodeStats(1), NodeStats(learner_config_, GradientStats(4.0f, 0.13f))}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect two candidate splits (one per root). - EXPECT_EQ(2, split_candidates.size()); - - // Verify first candidate for root 0, gain is expected to be the same as - // the left child since the root node gain is zero. - const SplitStats expected_split_stats0(learner_config_, root_stats[0]); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kLeaf, tree_node0.node_case()); - EXPECT_EQ(1, tree_node0.leaf().sparse_vector().index_size()); - EXPECT_EQ(kClassId, tree_node0.leaf().sparse_vector().index(0)); - EXPECT_EQ(1, tree_node0.leaf().sparse_vector().value_size()); - EXPECT_EQ(root_stats[0].weight_contribution[0], - tree_node0.leaf().sparse_vector().value(0)); - - // Verify second candidate for root 3, gain is expected to be zero as - // the left child gain is equal to the parent gain. - const SplitStats expected_split_stats1(learner_config_, root_stats[1]); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kLeaf, tree_node1.node_case()); - EXPECT_EQ(1, tree_node1.leaf().sparse_vector().index_size()); - EXPECT_EQ(kClassId, tree_node1.leaf().sparse_vector().index(0)); - EXPECT_EQ(1, tree_node1.leaf().sparse_vector().value_size()); - EXPECT_EQ(root_stats[1].weight_contribution[0], - tree_node1.leaf().sparse_vector().value(0)); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc deleted file mode 100644 index 3a6c409f846c9ca0bd6b5101e96447642b949978..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h" - -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a categorical Id split node without assigning children. -boosted_trees::trees::TreeNode CreateCategoricalIdNode( - const int32 feature_column, const int32 id) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_categorical_id_binary_split(); - split->set_feature_column(feature_column); - split->set_feature_id(id); - return split_node; -} - -} // namespace - -void CategoricalFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all rows and aggregate gradient stats for each feature id. - const int64 num_rows = indices_.dimension(0); - for (int64 row_idx = 0; row_idx < num_rows; ++row_idx) { - auto example_idx = indices_(row_idx, 0); - auto feature_id = values_(row_idx); - const GradientStats norm_gradient_stats(example_first_order_gradients, - example_second_order_gradients, - example_idx); - auto partition_id = example_partition_ids[example_idx]; - gradient_stats_accumulator->AddStats(slot_id_, class_id_, partition_id, - feature_id, norm_gradient_stats); - } -} - -void CategoricalFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Build a reverse lookup of partition id to root idx. - std::unordered_map partition_id_to_root_idx; - partition_id_to_root_idx.reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - partition_id_to_root_idx[roots[root_idx]] = root_idx; - } - - // Initialize split candidates. - split_candidates->clear(); - if (!roots.empty()) { - FeatureSplitCandidate empty_candidate( - root_stats[0].weight_contribution.size()); - split_candidates->resize(roots.size(), empty_candidate); - } - for (auto& split_candidate : *split_candidates) { - split_candidate.split_stats.gain = std::numeric_limits::lowest(); - } - - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we evaluate every feature id as an equality split - // and pick the highest split gain. - for (const auto& entry : - gradient_stats_accumulator.GetFeatureStats(slot_id_)) { - DCHECK_EQ(entry.first.class_id, class_id_); - - // Get partition id and root node stats. - const int32 partition_id = entry.first.partition_id; - auto root_idx_it = partition_id_to_root_idx.find(partition_id); - if (root_idx_it == partition_id_to_root_idx.end()) { - // Inactive partition. - continue; - } - size_t root_idx = root_idx_it->second; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Get gradient stats. - const auto& left_gradient_stats = entry.second; - auto right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats. - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate and update best split candidate for the - // current root if needed. - FeatureSplitCandidate split_candidate( - slot_id_, - CreateCategoricalIdNode(feature_column_, entry.first.feature_id), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - FeatureSplitCandidate& best_split_candidate = (*split_candidates)[root_idx]; - if (TF_PREDICT_FALSE(best_split_candidate.tree_node.node_case() == - boosted_trees::trees::TreeNode::NODE_NOT_SET)) { - // Always replace candidates with no node set. - best_split_candidate = std::move(split_candidate); - } else if (TF_PREDICT_FALSE(split_candidate.split_stats.gain == - best_split_candidate.split_stats.gain)) { - // Tie break on feature id. - auto best_split_feature_id = - best_split_candidate.tree_node.categorical_id_binary_split() - .feature_id(); - if (entry.first.feature_id < best_split_feature_id) { - best_split_candidate = std::move(split_candidate); - } - } else if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h deleted file mode 100644 index ef964ba716c6adf9cf9c291cca5f52f7a6efe26f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a categorical feature column in the single class case. -class CategoricalFeatureColumnHandler : public FeatureColumnHandler { - public: - CategoricalFeatureColumnHandler(const int32 class_id, const int32 slot_id, - const int64 batch_size, - const int32 feature_column, - TTypes::ConstMatrix indices, - TTypes::ConstVec values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - feature_column_(feature_column), - indices_(indices), - values_(values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 feature_column_; - TTypes::ConstMatrix indices_; - TTypes::ConstVec values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_CATEGORICAL_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc deleted file mode 100644 index ea82b3f086d24dc1f9ceb4783abd68be35b34b00..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler_test.cc +++ /dev/null @@ -1,165 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 7; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 3; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class CategoricalFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Feature Id | - // i0 | (0.2, 0.12) | 0 | 1,2 | - // i1 | (-0.5, 0.07) | 0 | | - // i2 | (1.2, 0.2) | 0 | 2 | - // i3 | (4.0, 0.13) | 1 | 0 | - CategoricalFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - indices_(test::AsTensor({0, 0, 0, 1, 2, 0, 3, 0}, {4, 2})), - values_(test::AsTensor({1, 2, 2, 0}, {4})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. - handler_.reset(new CategoricalFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, indices_.matrix(), - values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor indices_; - const Tensor values_; - std::unique_ptr handler_; -}; - -TEST_F(CategoricalFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f, 0.12f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 0, Feature 2. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f + 1.2f, 0.12f + 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 2)); - - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); - // Partition 1, Feature 2. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 2)); -} - -TEST_F(CategoricalFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. - const std::vector roots = {0, 1, 5}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i0, i2 left and i1 right. - const NodeStats expected_left_node0(learner_config_, - GradientStats(0.2f + 1.2f, 0.12f + 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ( - boosted_trees::trees::TreeNode::kCategoricalIdBinarySplitFieldNumber, - tree_node0.node_case()); - const auto& split0 = tree_node0.categorical_id_binary_split(); - EXPECT_EQ(2, split0.feature_id()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active feature here - // so zero gain is expected. - const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ( - boosted_trees::trees::TreeNode::kCategoricalIdBinarySplitFieldNumber, - tree_node1.node_case()); - const auto& split1 = tree_node1.categorical_id_binary_split(); - EXPECT_EQ(0, split1.feature_id()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 5. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc deleted file mode 100644 index ca7bb71e7d0b0fc945ee29092b1e36022d4c0943..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a dense split node without assigning children. -boosted_trees::trees::TreeNode CreateDenseSplitNode(const int32 feature_column, - const float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_dense_float_binary_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -} // namespace - -void DenseQuantizedFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all examples and aggregate gradient stats for each partition - // and quantized feature bucket. - for (int64 example_idx = 0; example_idx < batch_size_; ++example_idx) { - auto partition_id = example_partition_ids[example_idx]; - auto feature_id = dense_quantized_values_(example_idx); - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, feature_id, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void DenseQuantizedFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we do a forward-only pass over the quantized - // feature buckets accumulating gradients from left to right. - // Split gains are evaluated at every threshold and the best split is picked. - split_candidates->clear(); - split_candidates->reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - // Get partition Id and root node stats. - const int32 partition_id = roots[root_idx]; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Forward left to right pass over quantiles. - GradientStats left_gradient_stats; - GradientStats right_gradient_stats(root_node_stats.gradient_stats); - FeatureSplitCandidate best_split_candidate( - root_node_stats.weight_contribution.size()); - best_split_candidate.split_stats.gain = - std::numeric_limits::lowest(); - for (int bucket_id = 0; bucket_id < dense_quantiles_.size(); ++bucket_id) { - // Get gradient stats. - auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. - left_gradient_stats += gradient_stats; - right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. - const float threshold = dense_quantiles_(bucket_id); - FeatureSplitCandidate split_candidate( - slot_id_, CreateDenseSplitNode(dense_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - - // Add best candidate for partition. - split_candidates->push_back(std::move(best_split_candidate)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h deleted file mode 100644 index 0f3858e4d8c406e9ec3ae7079b241e94ef4aa35c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a dense quantized feature column in the single class case. -class DenseQuantizedFeatureColumnHandler : public FeatureColumnHandler { - public: - DenseQuantizedFeatureColumnHandler( - const int32 class_id, const int32 slot_id, const int64 batch_size, - const int32 dense_feature_column, TTypes::ConstVec dense_quantiles, - TTypes::ConstVec dense_quantized_values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - dense_feature_column_(dense_feature_column), - dense_quantiles_(dense_quantiles), - dense_quantized_values_(dense_quantized_values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 dense_feature_column_; - TTypes::ConstVec dense_quantiles_; - TTypes::ConstVec dense_quantized_values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_DENSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc deleted file mode 100644 index 1bc9d733ad3090f1cfc9547644061f54d7d2c8c6..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler_test.cc +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 1; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 2; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class DenseQuantizedFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Dense Quantile | - // i0 | (0.2, 0.12) | 0 | 1 | - // i1 | (-0.5, 0.07) | 0 | 1 | - // i2 | (1.2, 0.2) | 0 | 0 | - // i3 | (4.0, 0.13) | 1 | 1 | - DenseQuantizedFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - dense_quantiles_(test::AsTensor({0.3f, 0.52f}, {2})), - dense_quantized_values_(test::AsTensor({1, 1, 0, 1}, {4})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. - handler_.reset(new DenseQuantizedFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, - dense_quantiles_.vec(), dense_quantized_values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor dense_quantiles_; - const Tensor dense_quantized_values_; - std::unique_ptr handler_; -}; - -TEST_F(DenseQuantizedFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(-0.3f, 0.19f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); -} - -TEST_F(DenseQuantizedFeatureColumnHandlerTest, GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. - const std::vector roots = {0, 1, 5}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i2 left and i0, i1 right. - const NodeStats expected_left_node0(learner_config_, - GradientStats(1.2f, 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kDenseFloatBinarySplit, - tree_node0.node_case()); - const auto& split0 = tree_node0.dense_float_binary_split(); - EXPECT_FLOAT_EQ(dense_quantiles_.vec()(0), split0.threshold()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active bucket here - // so zero gain is expected. - const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kDenseFloatBinarySplit, - tree_node1.node_case()); - const auto& split1 = tree_node1.dense_float_binary_split(); - EXPECT_FLOAT_EQ(dense_quantiles_.vec()(1), split1.threshold()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 5. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h deleted file mode 100644 index 8bd2092f9609cb684b89f70cab35a92789fb39a4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ - -#include -#include "tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h" -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/stats/feature-split-candidate.h" -#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_types.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler interface for feature columns. Each feature column type may -// have its own handler which encapsulates the logic of aggregating gradient -// stats as well as generating split candidates for each partition. -// Handlers can be stateful and must be thread compatible. -class FeatureColumnHandler { - public: - FeatureColumnHandler(const int32 class_id, const int32 slot_id, - const int64 batch_size) - : class_id_(class_id), slot_id_(slot_id), batch_size_(batch_size) {} - - virtual ~FeatureColumnHandler() {} - FeatureColumnHandler(const FeatureColumnHandler& other) = delete; - FeatureColumnHandler& operator=(const FeatureColumnHandler& other) = delete; - - // Aggregates example gradient stats for the feature column. - virtual void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const = 0; - - // Generates feature column split candidates for the specified roots. - virtual void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const = 0; - - // Accessors. - int32 class_id() const { return class_id_; } - int32 slot_id() const { return slot_id_; } - int64 batch_size() const { return batch_size_; } - - protected: - // The class Id. - const int32 class_id_; - - // The slod Id for use as a unique Id across all feature columns. - const int32 slot_id_; - - // Size of the batch of examples. - const int64 batch_size_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc deleted file mode 100644 index a0e9efbbc5030e8c2e25fafab98271337a2e582a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc +++ /dev/null @@ -1,172 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -namespace { - -// Creates a sparse default right split node without assigning children. -boosted_trees::trees::TreeNode CreateSparseSplitNodeDefaultRight( - int32 feature_column, float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_sparse_float_binary_split_default_right() - ->mutable_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -// Creates a sparse default left split node without assigning children. -boosted_trees::trees::TreeNode CreateSparseSplitNodeDefaultLeft( - int32 feature_column, float threshold) { - boosted_trees::trees::TreeNode split_node; - auto* split = split_node.mutable_sparse_float_binary_split_default_left() - ->mutable_split(); - split->set_feature_column(feature_column); - split->set_threshold(threshold); - return split_node; -} - -} // namespace - -void SparseQuantizedFeatureColumnHandler::AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const { - // Pass over all rows and aggregate gradient stats for each partition - // and quantized feature bucket. - const int64 num_rows = sparse_indices_.dimension(0); - for (int64 row_idx = 0; row_idx < num_rows; ++row_idx) { - auto example_idx = sparse_indices_(row_idx, 0); - auto partition_id = example_partition_ids[example_idx]; - auto feature_id = sparse_quantized_values_(row_idx); - gradient_stats_accumulator->AddStats( - slot_id_, class_id_, partition_id, feature_id, - GradientStats(example_first_order_gradients, - example_second_order_gradients, example_idx)); - } -} - -void SparseQuantizedFeatureColumnHandler::GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const { - // Evaluate split candidates for every root as each is a separate - // logical partition over the examples. - // Then for each root, we do both a forward left to right pass and a backward - // right to left pass over the quantized feature buckets accumulating - // gradients on one side and using the root aggregate gradients to get the - // gradients for the other side. Split gains are evaluated for each pass at - // every threshold and the best split is picked. - split_candidates->clear(); - split_candidates->reserve(roots.size()); - for (size_t root_idx = 0; root_idx < roots.size(); ++root_idx) { - // Get partition Id and root node stats. - const int32 partition_id = roots[root_idx]; - const NodeStats& root_node_stats = root_stats[root_idx]; - - // Forward pass with right default direction. - GradientStats left_gradient_stats; - GradientStats right_gradient_stats(root_node_stats.gradient_stats); - FeatureSplitCandidate best_split_candidate( - root_node_stats.weight_contribution.size()); - best_split_candidate.split_stats.gain = - std::numeric_limits::lowest(); - for (int bucket_id = 0; bucket_id < sparse_quantiles_.size(); ++bucket_id) { - // Get gradient stats. - auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. - left_gradient_stats += gradient_stats; - right_gradient_stats = - root_node_stats.gradient_stats - left_gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. - const float threshold = sparse_quantiles_(bucket_id); - FeatureSplitCandidate split_candidate( - slot_id_, - CreateSparseSplitNodeDefaultRight(sparse_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - - // Determine if we need a backward pass by checking if the residual gradient - // after forward aggregation is almost the same as the aggregated gradient. - // for the current root. This helps avoid unnecessary computation as well - // as consistency due to floating point precision. - if (!right_gradient_stats.IsAlmostZero()) { - // Backward pass with left default direction. - right_gradient_stats = GradientStats(); - left_gradient_stats = root_node_stats.gradient_stats; - for (int bucket_id = sparse_quantiles_.size() - 1; bucket_id > 0; - --bucket_id) { - // Get gradient stats. - auto gradient_stats = gradient_stats_accumulator.GetStats( - slot_id_, class_id_, partition_id, bucket_id); - if (gradient_stats.IsZero()) { - continue; - } - - // Update gradient stats. - right_gradient_stats += gradient_stats; - left_gradient_stats = root_node_stats.gradient_stats - gradient_stats; - - // Get node stats - NodeStats left_node_stats(learner_config, left_gradient_stats); - NodeStats right_node_stats(learner_config, right_gradient_stats); - - // Generate split candidate. - const float threshold = sparse_quantiles_(bucket_id - 1); - FeatureSplitCandidate split_candidate( - slot_id_, - CreateSparseSplitNodeDefaultLeft(sparse_feature_column_, threshold), - SplitStats(learner_config, root_node_stats, left_node_stats, - right_node_stats)); - if (split_candidate.split_stats.gain > - best_split_candidate.split_stats.gain) { - best_split_candidate = std::move(split_candidate); - } - } - } - - // Add best candidate for partition. - split_candidates->push_back(std::move(best_split_candidate)); - } -} - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h deleted file mode 100644 index eb63e705471a65e8448bda38b2e31eb971d5c1bb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/feature-column-handler.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { - -// Handler for a sparse quantized feature column in the single class case. -class SparseQuantizedFeatureColumnHandler : public FeatureColumnHandler { - public: - SparseQuantizedFeatureColumnHandler( - const int32 class_id, const int32 slot_id, const int64 batch_size, - const int32 sparse_feature_column, - TTypes::ConstVec sparse_quantiles, - TTypes::ConstMatrix sparse_indices, - TTypes::ConstVec sparse_quantized_values) - : FeatureColumnHandler(class_id, slot_id, batch_size), - sparse_feature_column_(sparse_feature_column), - sparse_quantiles_(sparse_quantiles), - sparse_indices_(sparse_indices), - sparse_quantized_values_(sparse_quantized_values) {} - - void AggregateGradientStats( - const std::vector& example_partition_ids, - const Tensor& example_first_order_gradients, - const Tensor& example_second_order_gradients, - FeatureStatsAccumulator* - gradient_stats_accumulator) const override; - - void GenerateFeatureSplitCandidates( - const LearnerConfig& learner_config, const std::vector& roots, - const std::vector& root_stats, - const FeatureStatsAccumulator& - gradient_stats_accumulator, - std::vector* split_candidates) const override; - - protected: - const int32 sparse_feature_column_; - TTypes::ConstVec sparse_quantiles_; - TTypes::ConstMatrix sparse_indices_; - TTypes::ConstVec sparse_quantized_values_; -}; - -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_STOCHASTIC_HANDLERS_SPARSE_QUANTIZED_FEATURE_COLUMN_HANDLER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc deleted file mode 100644 index 643d936ad23850e601bc5518d69c8637011f53c0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler_test.cc +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.h" - -#include "tensorflow/core/framework/tensor_testutil.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace boosted_trees { -namespace learner { -namespace stochastic { -namespace { - -using boosted_trees::learner::LearnerConfig; - -const auto kClassId = 3; -const auto kSlotId = 0; -const auto kBatchSize = 4; -const auto kFeatureColumn = 4; - -using FeatureStatsAccumulator = - FeatureStatsAccumulator; - -class SparseQuantizedFeatureColumnHandlerTest : public ::testing::Test { - protected: - // The data looks like the following: - // Example | Gradients | Partition | Sparse Quantile | - // i0 | (0.2, 0.12) | 0 | 1 | - // i1 | (-0.5, 0.07) | 0 | N/A | - // i2 | (1.2, 0.2) | 0 | 0 | - // i3 | (4.0, 0.13) | 1 | 1 | - SparseQuantizedFeatureColumnHandlerTest() - : example_first_order_gradients_( - test::AsTensor({0.2f, -0.5f, 1.2f, 4.0f}, {4})), - example_second_order_gradients_( - test::AsTensor({0.12f, 0.07f, 0.2f, 0.13f}, {4})), - example_partitions_({0, 0, 0, 1}), - sparse_quantiles_(test::AsTensor({0.3f, 0.52f}, {2})), - sparse_indices_(test::AsTensor({0, 0, 2, 0, 3, 0}, {3, 2})), - sparse_quantized_values_(test::AsTensor({1, 0, 1}, {3})) { - // Set L2 regularization. - learner_config_.mutable_regularization()->set_l2(2.0f); - learner_config_.set_multi_class_strategy(LearnerConfig::TREE_PER_CLASS); - // Create handler. - handler_.reset(new SparseQuantizedFeatureColumnHandler( - kClassId, kSlotId, kBatchSize, kFeatureColumn, - sparse_quantiles_.vec(), sparse_indices_.matrix(), - sparse_quantized_values_.vec())); - } - - LearnerConfig learner_config_; - const Tensor example_first_order_gradients_; - const Tensor example_second_order_gradients_; - const std::vector example_partitions_; - const Tensor sparse_quantiles_; - const Tensor sparse_indices_; - const Tensor sparse_quantized_values_; - std::unique_ptr handler_; -}; - -TEST_F(SparseQuantizedFeatureColumnHandlerTest, AggregateGradientStats) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Check stats for each partition and feature. - // Partition 0, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(1.2f, 0.2f), - accumulator.GetStats(kSlotId, kClassId, 0, 0)); - // Partition 0, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.2f, 0.12f), - accumulator.GetStats(kSlotId, kClassId, 0, 1)); - // Partition 1, Feature 0. - EXPECT_GRADIENT_STATS_EQ(GradientStats(0.0f, 0.0f), - accumulator.GetStats(kSlotId, kClassId, 1, 0)); - // Partition 1, Feature 1. - EXPECT_GRADIENT_STATS_EQ(GradientStats(4.0f, 0.13f), - accumulator.GetStats(kSlotId, kClassId, 1, 1)); -} - -TEST_F(SparseQuantizedFeatureColumnHandlerTest, - GenerateFeatureSplitCandidates) { - // Create handler. - FeatureStatsAccumulator accumulator(1); - handler_->AggregateGradientStats( - example_partitions_, example_first_order_gradients_, - example_second_order_gradients_, &accumulator); - - // Get feature split candidates for two roots 0 and 1. - // The root stats are derived from the per-partition total gradient stats. - const std::vector roots = {0, 1, 9}; - const std::vector& root_stats = { - NodeStats(learner_config_, GradientStats(0.9f, 0.39f)), - NodeStats(learner_config_, GradientStats(4.0f, 0.13f)), NodeStats(1)}; - std::vector split_candidates; - handler_->GenerateFeatureSplitCandidates(learner_config_, roots, root_stats, - accumulator, &split_candidates); - // Expect three candidate splits (one per root). - EXPECT_EQ(3, split_candidates.size()); - - // Verify candidate for root 0, the best split occurs when we route - // example i0 and i2 to the left and i1 to the right (by default direction). - const NodeStats expected_left_node0(learner_config_, - GradientStats(0.2f + 1.2f, 0.12f + 0.2f)); - const NodeStats expected_right_node0( - learner_config_, - root_stats[0].gradient_stats - expected_left_node0.gradient_stats); - const SplitStats expected_split_stats0(learner_config_, root_stats[0], - expected_left_node0, - expected_right_node0); - EXPECT_SPLIT_STATS_EQ(expected_split_stats0, split_candidates[0].split_stats); - const auto& tree_node0 = split_candidates[0].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kSparseFloatBinarySplitDefaultRight, - tree_node0.node_case()); - const auto& split0 = - tree_node0.sparse_float_binary_split_default_right().split(); - EXPECT_FLOAT_EQ(sparse_quantiles_.vec()(1), split0.threshold()); - EXPECT_EQ(kFeatureColumn, split0.feature_column()); - - // Verify candidate for root 1, there's only one active bucket here - // so zero gain is expected. - const NodeStats expected_left_node1(learner_config_, - root_stats[1].gradient_stats); - const NodeStats expected_right_node1(learner_config_, GradientStats(0, 0)); - const SplitStats expected_split_stats1(learner_config_, root_stats[1], - expected_left_node1, - expected_right_node1); - EXPECT_SPLIT_STATS_EQ(expected_split_stats1, split_candidates[1].split_stats); - const auto& tree_node1 = split_candidates[1].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::kSparseFloatBinarySplitDefaultRight, - tree_node1.node_case()); - const auto& split1 = - tree_node1.sparse_float_binary_split_default_right().split(); - EXPECT_FLOAT_EQ(sparse_quantiles_.vec()(1), split1.threshold()); - EXPECT_EQ(kFeatureColumn, split1.feature_column()); - - // Verify there are no candidate splits for root 9. - const auto& tree_node2 = split_candidates[2].tree_node; - EXPECT_EQ(boosted_trees::trees::TreeNode::NODE_NOT_SET, - tree_node2.node_case()); -} - -} // namespace -} // namespace stochastic -} // namespace learner -} // namespace boosted_trees -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h index 5e316538cefed30b2867252c9ebc4754216db329..70037d5bd8f446bdbbfcc468edb8a76c05e4fab7 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h @@ -33,9 +33,9 @@ template mutable_leaf(); // Split on first column - split_node->set_feature_id(0); + split_node->set_dimension_id(0); split_node->set_threshold(2.0f); // Both instances have this feature value. @@ -199,7 +199,7 @@ TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); // Split on second column - split_node->set_feature_id(1); + split_node->set_dimension_id(1); split_node->set_threshold(5.0f); // First instance does not have it (default right), second does have it. @@ -208,7 +208,7 @@ TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); // Split on third column - split_node->set_feature_id(2); + split_node->set_dimension_id(2); split_node->set_threshold(3.0f); example_it = example_iterable.begin(); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index 7a550d6f7328765d8815a947885e47fa0b0a8f8b..badc629a118f768d5aa25ef1b94b8190e6910c7f 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -56,7 +56,7 @@ class BatchFeatures { *num_sparse_int_features = sparse_int_feature_columns_.size(); if (*num_dense_float_features == 0 && *num_sparse_float_features == 0 && *num_sparse_int_features == 0) { - return errors::FailedPrecondition("Not intialized yet."); + return errors::FailedPrecondition("Not initialized yet."); } return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h index e388cf332c3ff327f79ea57e3a0bccbbaa1b5e45..54f60e1dee49a4a40b84fcc6e042fac1858aa187 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/example.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h @@ -63,7 +63,7 @@ class SparseFloatFeatureColumn { public: void Reserve(const int32 size) { if (!single_dimensional_) { - mutlidimensional_values.Reserve(size); + multidimensional_values.Reserve(size); } } @@ -76,7 +76,7 @@ class SparseFloatFeatureColumn { DCHECK_EQ(0, feature_idx); single_value_ = value; } else { - mutlidimensional_values.Add(feature_idx, value); + multidimensional_values.Add(feature_idx, value); } initialized_ = true; } @@ -84,7 +84,7 @@ class SparseFloatFeatureColumn { void Clear() { single_dimensional_ = false; initialized_ = false; - mutlidimensional_values.Clear(); + multidimensional_values.Clear(); } OptionalValue operator[](int feature_idx) const { @@ -94,7 +94,7 @@ class SparseFloatFeatureColumn { if (single_dimensional_) { return OptionalValue(single_value_); } else { - return mutlidimensional_values[feature_idx]; + return multidimensional_values[feature_idx]; } } @@ -102,7 +102,7 @@ class SparseFloatFeatureColumn { bool single_dimensional_; bool initialized_; T single_value_; - SparseMultidimensionalValues mutlidimensional_values; + SparseMultidimensionalValues multidimensional_values; }; // Holds data for one example and enables lookup by feature column. diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc index bc0a93db8c39abf737d11682088233e2fd88e868..ccee9530b6897924453461c13b1238402c0f6cfa 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc @@ -96,6 +96,10 @@ class IndicesRowIterator return (row_idx_ != other.row_idx_); } + bool operator<(const IndicesRowIterator& other) const { + return (row_idx_ < other.row_idx_); + } + bool operator==(const IndicesRowIterator& other) const { QCHECK_EQ(iter_, other.iter_); return (row_idx_ == other.row_idx_); diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index 82b8e8c1c272ca415b5841f5ba9433e00173f8fa..d66f645f62aba84261337eb37d6e3204930f8f15 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -36,7 +36,7 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, reduce_dim ? learner_config.num_classes() - 1 : learner_config.num_classes())}); - c->set_output(1, {c->Vector(InferenceContext::kUnknownDim)}); + c->set_output(1, {c->UnknownShape()}); return Status::OK(); } diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index 4ca73ef6e3301aadda48d5c971c31b57b7925614..1fa70bafddb0c94f47d006d5694bea941edaddf9 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -268,6 +268,7 @@ REGISTER_OP("Quantiles") .Input("sparse_values: num_sparse_features * float") .Input("dense_buckets: num_dense_features * float") .Input("sparse_buckets: num_sparse_features * float") + .Input("sparse_indices: num_sparse_features * int64") .Output("dense_quantiles: num_dense_features * int32") .Output("sparse_quantiles: num_sparse_features * int32") .Doc(R"doc( @@ -280,10 +281,13 @@ dense_values: List of rank 1 tensors containing the dense values. sparse_values: List of rank 1 tensors containing the sparse feature values. dense_buckets: Quantile summary for each of the dense float tensor. sparse_buckets: Quantile summary for each of the sparse feature float tensor. -dense_quantiles: Rank 1 tensors representing associated quantiles for each of -dense float tensors. -sparse_quantiles: Rank 1 tensors representing associated quantiles for each of -the sparse feature tensors. +sparse_indices: List of rank 2 tensors with indices for sparse float +tensors. +dense_quantiles: Rank 2 tensors representing associated quantiles for each of +dense float tensors and the dimension. +sparse_quantiles: Rank 2 tensors representing associated quantiles for each of +the sparse feature tensors for each of sparse feature dimensions: +[quantile id, dimension id]. )doc"); REGISTER_OP("BucketizeWithInputBoundaries") diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 07cfd413bbd389053ff52ca65693445ef28e8ede..0d27ddaf3a1d540efee268c2bcca217077ff5871 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -47,9 +47,7 @@ REGISTER_OP("BuildDenseInequalitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -71,7 +69,7 @@ Find the split that has the best gain for the accumulated stats. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. partition_ids: A rank 1 tensor of partition IDs. -bucket_ids: A rank 1 tensor of buckets IDs. +bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. @@ -108,9 +106,7 @@ REGISTER_OP("BuildSparseInequalitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -127,12 +123,13 @@ REGISTER_OP("BuildSparseInequalitySplits") return Status::OK(); }) .Doc(R"doc( -Find the split that has the best gain for the accumulated stats. +Find the split that has the best gain for the accumulated stats for a particular +feature column. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. -partition_ids: A rank 1 tensor of partition IDs. -bucket_ids: A rank 1 tensor of buckets IDs. +partition_ids: A rank 2 tensor of partition IDs for each dimension of feature column. +bucket_ids: A rank 2 tensor of buckets IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. bucket_boundaries: A rank 1 tensor, thresholds that were used for bucketization. @@ -168,9 +165,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &partition_ids_shape)); ShapeHandle bucket_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &bucket_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(bucket_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &bucket_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(3), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -190,7 +185,7 @@ Find the split that has the best gain for the accumulated stats. num_minibatches: A scalar, the number of times per example gradients & hessians were accumulated. The stats are divided by this to get per example stats. partition_ids: A rank 1 tensor of partition IDs. -feature_ids: A rank 1 tensor of feature IDs. +feature_ids: A rank 2 tensor of feature IDs and dimensions. gradients: A rank 1 tensor of gradients. hessians: A rank 1 tensor of hessians. output_partition_ids: A rank 1 tensor, the partition IDs that we created splits diff --git a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc index f988755de021034fc0d33529286dd3b508d746ed..0354f7853cbedf22d0a299273b4dbd225b3121ab 100644 --- a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc @@ -73,9 +73,7 @@ REGISTER_OP("StatsAccumulatorScalarAdd") 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; TF_RETURN_IF_ERROR(c->WithRank( - c->input(num_resource_handles * 2 + i + 1), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + c->input(num_resource_handles * 2 + i + 1), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRank( c->input(num_resource_handles * 3 + i + 1), 1, &gradients_shape)); @@ -96,11 +94,11 @@ stamp_token: Stamp token for Read/Write operations. Any operation with a mismatching token will be dropped. stats_accumulator_handles: A list of handles to the stats accumulator. partition_ids: A list of vectors of partition_ids. -feature_ids: A list of vectors of feature_ids. +feature_ids: Rank 2 tensor of feature id and feature dimension ids. gradients: A list of vectors of gradients for each slot in - . + . hessians: A list of vectors of hessians for each slot in - . + . )doc"); REGISTER_OP("StatsAccumulatorScalarFlush") @@ -119,7 +117,7 @@ REGISTER_OP("StatsAccumulatorScalarFlush") TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); c->set_output(0, c->Scalar()); c->set_output(1, c->Vector(c->UnknownDim())); - c->set_output(2, c->Vector(c->UnknownDim())); + c->set_output(2, c->UnknownShape()); c->set_output(3, c->Vector(c->UnknownDim())); c->set_output(4, c->Vector(c->UnknownDim())); return Status::OK(); @@ -134,7 +132,7 @@ next_stamp_token: Stamp token for the next iteration. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A vector of gradients, with a value for each slot in . output_hessians: A vector of hessians, with a value for each slot @@ -161,9 +159,7 @@ REGISTER_OP("StatsAccumulatorScalarDeserialize") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -183,9 +179,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . -hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. +gradients: A vector of gradients for each slot in . +hessians: A vector of hessians for each slot in )doc"); REGISTER_OP("StatsAccumulatorScalarSerialize") @@ -204,7 +202,7 @@ REGISTER_OP("StatsAccumulatorScalarSerialize") // num_updates c->set_output(1, c->Scalar()); c->set_output(2, c->Vector(c->UnknownDim())); - c->set_output(3, c->Vector(c->UnknownDim())); + c->set_output(3, c->UnknownShape()); c->set_output(4, c->Vector(c->UnknownDim())); c->set_output(5, c->Vector(c->UnknownDim())); return Status::OK(); @@ -217,7 +215,7 @@ stamp_token: The current stamp token for the resource. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A vector of gradients, with a value for each slot in . output_hessians: A vector of hessians, with a value for each slot @@ -293,9 +291,7 @@ REGISTER_OP("StatsAccumulatorTensorAdd") 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; TF_RETURN_IF_ERROR(c->WithRank( - c->input(num_resource_handles * 2 + i + 1), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + c->input(num_resource_handles * 2 + i + 1), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast( c->input(num_resource_handles * 3 + i + 1), 2, &gradients_shape)); @@ -316,11 +312,11 @@ stats_accumulator_handles: A list of handles to the stats accumulator. stamp_token: Stamp token for Read/Write operations. Any operation with a mismatching token will be dropped. partition_ids: A list of vectors of partition_ids. -feature_ids: A list of vectors of feature_ids. +feature_ids: Rank 2 tensor of feature id and feature dimension ids. gradients: A list of vectors of gradients for each slot in - . + . hessians: A list of vectors of hessians for each slot in - . + . )doc"); REGISTER_OP("StatsAccumulatorTensorFlush") @@ -340,7 +336,7 @@ REGISTER_OP("StatsAccumulatorTensorFlush") // num_updates c->set_output(0, c->Scalar()); c->set_output(1, c->Vector(c->UnknownDim())); - c->set_output(2, c->Vector(c->UnknownDim())); + c->set_output(2, c->UnknownShape()); c->set_output(3, c->UnknownShape()); c->set_output(4, c->UnknownShape()); return Status::OK(); @@ -355,11 +351,11 @@ next_stamp_token: Stamp token to be used for the next iteration. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in >. )doc"); REGISTER_OP("StatsAccumulatorTensorDeserialize") @@ -382,9 +378,7 @@ REGISTER_OP("StatsAccumulatorTensorDeserialize") ShapeHandle partition_ids_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &partition_ids_shape)); ShapeHandle feature_ids_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &feature_ids_shape)); - TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), - c->Dim(feature_ids_shape, 0), &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &feature_ids_shape)); ShapeHandle gradients_shape; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(5), 2, &gradients_shape)); TF_RETURN_IF_ERROR(c->Merge(c->Dim(partition_ids_shape, 0), @@ -405,9 +399,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . -hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. +gradients: A vector of gradients for each slot in +hessians: A vector of hessians for each slot in . )doc"); REGISTER_OP("StatsAccumulatorTensorSerialize") @@ -426,7 +422,7 @@ REGISTER_OP("StatsAccumulatorTensorSerialize") // num_updates c->set_output(1, c->Scalar()); c->set_output(2, c->Vector(c->UnknownDim())); - c->set_output(3, c->Vector(c->UnknownDim())); + c->set_output(3, c->UnknownShape()); c->set_output(4, c->UnknownShape()); c->set_output(5, c->UnknownShape()); return Status::OK(); @@ -440,11 +436,11 @@ stamp_token: Stamp token for Read/Write operations. num_updates: Number of times stats were added to this accumulator since last flush. output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: Rank 2 tensor of feature id and feature dimension ids. output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in . )doc"); REGISTER_OP("StatsAccumulatorTensorMakeSummary") @@ -458,18 +454,20 @@ REGISTER_OP("StatsAccumulatorTensorMakeSummary") .Output("output_hessians: float") .Doc(R"doc( Summarizes the stats by summing the that are for the same -. +. partition_ids: A vector of partition_ids. -feature_ids: A vector of feature_ids. -gradients: A vector of gradients for each slot in . -hessians: A vector of hessians for each slot in . +feature_ids: Rank 2 tensor of feature id and feature dimension ids. +gradients: A vector of gradients for each slot in . +hessians: A vector of hessians for each slot in . output_partition_ids: A vector of partition_ids for the slots. -output_feature_ids: A vector of feature_ids for the slots. +output_feature_ids: A rank2 tensor of feature_ids and dimensions for the slots. output_gradients: A tensor of gradients, first dimension matches slots - in . + in . output_hessians: A tensor of hessians, first dimension matches slots - in . + in . )doc"); } // namespace boosted_trees } // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index f14abf45a517ad7c4c6d7bb1ab88b7a1d47d6fb6..fc570c1083d01a65760a456c109dad93afd9f62a 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,9 +53,9 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; - // If feature column is multivalent, this holds the index of the feature for - // the split. Defaults to 0. - int32 feature_id = 5; + // If feature column is multivalent, this holds the index of the dimensiong + // for the split. Defaults to 0. + int32 dimension_id = 5; float threshold = 2; // Node children indexing into a contiguous diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index cf0958511350f82d548c56849f6179ae0f0215f5..c1acf351603dd80c2d14c7ee0a5b4c89706bc1bf 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -75,7 +75,7 @@ def _append_multi_values_to_dense_leaf(leaf, w): leaf.vector.value.append(x) -def _set_float_split(split, feat_col, thresh, l_id, r_id): +def _set_float_split(split, feat_col, thresh, l_id, r_id, feature_dim_id=None): """Helper method for building tree float splits. Sets split feature column, threshold and children. @@ -86,11 +86,14 @@ def _set_float_split(split, feat_col, thresh, l_id, r_id): thresh: threshold to split on forming rule x <= thresh. l_id: left child Id. r_id: right child Id. + feature_dim_id: dimension of the feature column to be used in the split. """ split.feature_column = feat_col split.threshold = thresh split.left_id = l_id split.right_id = r_id + if feature_dim_id is not None: + split.dimension_id = feature_dim_id def _set_categorical_id_split(split, feat_col, feat_id, l_id, r_id): @@ -116,12 +119,12 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): def setUp(self): """Sets up the prediction tests. - Create a batch of two examples having one dense float, two sparse float and - one sparse int features. + Create a batch of two examples having one dense float, two sparse float + single valued, one sparse float multidimensionl and one sparse int features. The data looks like the following: - | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | - | 0 | 7 | -3 | | 9,1 | - | 1 | -2 | | 4 | | + | Instance | Dense0 | SparseF0 | SparseF1 | SparseI0 | SparseM + | 0 | 7 | -3 | | 9,1 | __, 5.0 + | 1 | -2 | | 4 | | 3, ___ """ super(PredictionOpsTest, self).setUp() self._dense_float_tensor = np.array([[7.0], [-2.0]]) @@ -131,11 +134,37 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self._sparse_float_indices2 = np.array([[1, 0]]) self._sparse_float_values2 = np.array([4.0]) self._sparse_float_shape2 = np.array([2, 1]) + # Multi dimensional sparse float + self._sparse_float_indices_m = np.array([[0, 1], [1, 0]]) + self._sparse_float_values_m = np.array([5.0, 3.0]) + self._sparse_float_shape_m = np.array([2, 2]) + self._sparse_int_indices1 = np.array([[0, 0], [0, 1]]) self._sparse_int_values1 = np.array([9, 1]) self._sparse_int_shape1 = np.array([2, 2]) self._seed = 123 + def _get_predictions(self, + tree_ensemble_handle, + learner_config, + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=False): + return prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], + [self._sparse_float_indices1, self._sparse_float_indices2], + [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, self._sparse_float_shape2], + [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], + learner_config=learner_config, + apply_dropout=apply_dropout, + apply_averaging=apply_averaging, + center_bias=center_bias, + reduce_dim=reduce_dim) + def testEmptyEnsemble(self): with self.test_session(): # Empty tree ensenble. @@ -151,18 +180,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) self.assertAllEqual([[0], [0]], result.eval()) # Empty dropout. @@ -187,18 +207,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) self.assertAllClose([[-0.4], [-0.4]], result.eval()) @@ -226,18 +237,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 3 - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) self.assertAllClose([[-0.4, 0.9], [-0.4, 0.9]], result.eval()) @@ -279,14 +281,94 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) + + # The first example will get bias -0.4 from first tree and + # leaf 4 payload of -0.9 hence -1.3, the second example will + # get the same bias -0.4 and leaf 3 payload (sparse feature missing) + # of 1.2 hence 0.8. + self.assertAllClose([[-1.3], [0.8]], result.eval()) + + # Empty dropout. + self.assertAllEqual([[], []], dropout_info.eval()) + + def testFullEnsembleWithMultidimensionalSparseSingleClass(self): + with self.test_session(): + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + # Bias tree. + tree1 = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_metadata.add().is_finalized = True + _append_to_leaf(tree1.nodes.add().leaf, 0, -0.4) + + # Depth 3 tree. + tree2 = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_metadata.add().is_finalized = True + # Use feature column 2 (sparse multidimensional), split on first value + # node 0. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_right.split, + 2, + 7.0, + 1, + 2, + feature_dim_id=0) + # Leafs split on second dimension of sparse multidimensional feature. + # Node 1. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_left.split, + 2, + 4.5, + 3, + 4, + feature_dim_id=1) + # Node 2. + _set_float_split( + tree2.nodes.add().sparse_float_binary_split_default_right.split, + 2, + 9, + 5, + 6, + feature_dim_id=1) + + # Node 3. + _append_to_leaf(tree2.nodes.add().leaf, 0, 0.6) + # Node 4. + _append_to_leaf(tree2.nodes.add().leaf, 0, 1.3) + + # Node 5. + _append_to_leaf(tree2.nodes.add().leaf, 0, -0.1) + # Node 6. + _append_to_leaf(tree2.nodes.add().leaf, 0, 0.8) + + tree_ensemble_config.tree_weights.append(1.0) + tree_ensemble_config.tree_weights.append(1.0) + + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="full_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + result, dropout_info = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], + self._sparse_float_indices1, self._sparse_float_indices2, + self._sparse_float_indices_m + ], [ + self._sparse_float_values1, self._sparse_float_values2, + self._sparse_float_values_m + ], [ + self._sparse_float_shape1, self._sparse_float_shape2, + self._sparse_float_shape_m + ], [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=False, apply_averaging=False, @@ -294,10 +376,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): reduce_dim=True) # The first example will get bias -0.4 from first tree and - # leaf 4 payload of -0.9 hence -1.3, the second example will - # get the same bias -0.4 and leaf 3 payload (sparse feature missing) - # of 1.2 hence 0.8. - self.assertAllClose([[-1.3], [0.8]], result.eval()) + # leaf 5 payload of -0.1 hence -0.5, the second example will + # get the same bias -0.4 and leaf 3 payload (0.6) hence 0.2 + self.assertAllClose([[-0.5], [0.2]], result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -337,19 +418,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # All the examples should get only the bias since the second tree is @@ -394,19 +465,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER - - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # The first example will get bias -0.4 from first tree and @@ -453,19 +514,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Prepare learner config. learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # The first example will get bias -0.4 from first tree and @@ -512,18 +563,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.TREE_PER_CLASS) - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=True) # The first example will get bias class 1 -0.2 from first tree and # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], @@ -572,18 +614,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.FULL_HESSIAN) - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=False) # The first example will get bias class 1 -0.2 from first tree and # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], @@ -631,18 +664,9 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.FULL_HESSIAN) - result, dropout_info = prediction_ops.gradient_trees_prediction( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, reduce_dim=False) # The first example will get bias class 1 -0.2 and -2 for class 2 from # first tree and leaf 2 payload (sparse feature missing) of 0.5 hence @@ -653,26 +677,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) - def _get_predictions(self, - tree_ensemble_handle, - learner_config, - apply_dropout=False, - apply_averaging=False, - center_bias=False): - return prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - center_bias=center_bias, - reduce_dim=True) - def testDropout(self): with self.test_session(): # Empty tree ensenble. @@ -699,10 +703,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) # We expect approx 500 trees were dropped. dropout_info = dropout_info.eval() @@ -719,10 +724,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Don't apply dropout. result_no_dropout, no_dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=False, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) self.assertEqual(result.eval().size, result_no_dropout.eval().size) for i in range(result.eval().size): @@ -760,17 +766,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) result_center, dropout_info_center = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=True) + center_bias=True, + reduce_dim=True) dropout_info = dropout_info.eval() dropout_info_center = dropout_info_center.eval() @@ -830,17 +838,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) result_center, dropout_info_center = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=True) + center_bias=True, + reduce_dim=True) dropout_info = dropout_info.eval() dropout_info_center = dropout_info_center.eval() @@ -888,28 +898,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="empty") resources.initialize_resources(resources.shared_resources()).run() - _, dropout_info_1 = prediction_ops.gradient_trees_prediction( + _, dropout_info_1 = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, center_bias=False, reduce_dim=True) - _, dropout_info_2 = prediction_ops.gradient_trees_prediction( + _, dropout_info_2 = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, @@ -919,12 +917,12 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Different seed. _, dropout_info_3 = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, - 112314, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], + 112314, [self._dense_float_tensor], + [self._sparse_float_indices1, self._sparse_float_indices2], + [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, self._sparse_float_shape2], + [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, @@ -932,14 +930,8 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): reduce_dim=True) # First seed with centering bias. - _, dropout_info_4 = prediction_ops.gradient_trees_prediction( + _, dropout_info_4 = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, @@ -983,17 +975,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) result_no_dropout, _ = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=False, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) self.assertAllEqual([[], []], dropout_info.eval()) self.assertAllClose(result.eval(), result_no_dropout.eval()) @@ -1048,12 +1042,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): # Do averaging. result, dropout_info = self._get_predictions( - tree_ensemble_handle, learner_config, apply_averaging=True) + tree_ensemble_handle, + learner_config.SerializeToString(), + apply_averaging=True, + reduce_dim=True) - pattern_result, pattern_dropout_info = (self._get_predictions( + pattern_result, pattern_dropout_info = self._get_predictions( adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) self.assertAllEqual(result.eval(), pattern_result.eval()) self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval()) @@ -1116,15 +1114,22 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() result_1, dropout_info_1 = self._get_predictions( - tree_ensemble_handle, learner_config_1, apply_averaging=True) + tree_ensemble_handle, + learner_config_1.SerializeToString(), + apply_averaging=True, + reduce_dim=True) result_2, dropout_info_2 = self._get_predictions( - tree_ensemble_handle, learner_config_2, apply_averaging=True) + tree_ensemble_handle, + learner_config_2.SerializeToString(), + apply_averaging=True, + reduce_dim=True) pattern_result, pattern_dropout_info = self._get_predictions( adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False) + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) self.assertAllEqual(result_1.eval(), pattern_result.eval()) self.assertAllEqual(result_2.eval(), pattern_result.eval()) @@ -1179,12 +1184,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() result, dropout_info = self._get_predictions( - tree_ensemble_handle, learner_config, apply_averaging=True) + tree_ensemble_handle, + learner_config.SerializeToString(), + apply_averaging=True, + reduce_dim=True) - pattern_result, pattern_dropout_info = (self._get_predictions( + pattern_result, pattern_dropout_info = self._get_predictions( adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) self.assertAllEqual(result.eval(), pattern_result.eval()) self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval()) @@ -1224,10 +1233,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -1263,10 +1268,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -1302,10 +1303,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 1513c11c33d538dedabe10e4411bdd1373b16c7f..888d5c57ed33446c8b6f18d2d1e393647613d132 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -48,15 +48,16 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): def testBasicQuantileBuckets(self): """Sets up the quantile summary op test as follows. - Create a batch of 6 examples having a dense and sparse features. + Create a batch of 6 examples having a dense and sparse features. SparseM is + a sparse multi-dimensional (multivalent) feature. The data looks like this - | Instance | instance weights | Dense 0 | Sparse 0 - | 0 | 10 | 1 | - | 1 | 1 | 2 | 2 - | 2 | 1 | 3 | 3 - | 3 | 1 | 4 | 4 - | 4 | 1 | 4 | 5 - | 5 | 1 | 5 | 6 + | Instance | instance weights | Dense 0 | Sparse 0 | SparseM + | 0 | 10 | 1 | | | | + | 1 | 1 | 2 | 2 | 2 | | + | 2 | 1 | 3 | 3 | 3 | | + | 3 | 1 | 4 | 4 | | 4 | + | 4 | 1 | 4 | 5 | | 5 | + | 5 | 1 | 5 | 6 | | 6 | """ dense_float_tensor_0 = constant_op.constant( @@ -66,20 +67,29 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): sparse_values_0 = constant_op.constant( [2, 3, 4, 5, 6], dtype=dtypes.float32) sparse_shape_0 = constant_op.constant([6, 1], dtype=dtypes.int64) + # Multi-dimensional feature that should have the same quantiles as Sparse 0. + sparse_indices_m = constant_op.constant( + [[1, 1], [2, 0], [3, 1], [4, 1], [5, 1]], dtype=dtypes.int64) + sparse_values_m = constant_op.constant( + [2, 3, 4, 5, 6], dtype=dtypes.float32) + sparse_shape_m = constant_op.constant([6, 2], dtype=dtypes.int64) + example_weights = constant_op.constant( [10, 1, 1, 1, 1, 1], dtype=dtypes.float32) with self.test_session(): config = self._gen_config(0.33, 3) dense_buckets, sparse_buckets = quantile_ops.quantile_buckets( - [dense_float_tensor_0], [sparse_indices_0], [sparse_values_0], - [sparse_shape_0], + [dense_float_tensor_0], [sparse_indices_0, sparse_indices_m], + [sparse_values_0, sparse_values_m], [sparse_shape_0, sparse_shape_m], example_weights=example_weights, dense_config=[config], - sparse_config=[config]) + sparse_config=[config, config]) self.assertAllEqual([1, 3, 5], dense_buckets[0].eval()) self.assertAllEqual([2, 4, 6.], sparse_buckets[0].eval()) + # Multidimensional sparse. + self.assertAllEqual([2, 4, 6.], sparse_buckets[1].eval()) def testStreamingQuantileBucketsWithVaryingBatch(self): """Sets up the quantile summary op test as follows. @@ -214,10 +224,10 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() sparse_indices_0 = constant_op.constant( - [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]], dtype=dtypes.int64) + [[1, 0], [2, 1], [3, 0], [4, 2], [5, 0]], dtype=dtypes.int64) sparse_values_0 = constant_op.constant( [2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtypes.float32) - sparse_shape_0 = constant_op.constant([6, 1], dtype=dtypes.int64) + sparse_shape_0 = constant_op.constant([6, 3], dtype=dtypes.int64) example_weights = constant_op.constant( [10, 1, 1, 1, 1, 1], dtype=dtypes.float32, shape=[6, 1]) update = accumulator.add_summary( @@ -349,19 +359,21 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): def setUp(self): """Sets up the quantile op tests. - Create a batch of 4 examples having 2 dense and 3 sparse features. + Create a batch of 4 examples having 2 dense and 4 sparse features. + Forth sparse feature is multivalent (3 dimensional) The data looks like this - | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 | Sparse 2 - | 0 | -0.1 | -1 | -2 | 0.1 | - | 1 | 0.4 | -15 | 5.5 | | 2 - | 2 | 3.2 | 18 | 16 | 3 | - | 3 | 190 | 1000 | 17.5 | -3 | 4 + | Instance | Dense 0 | Dense 1 | Sparse 0 | Sparse 1 |Sparse 2| SparseM + | 0 | -0.1 | -1 | -2 | 0.1 | |_ ,1,_ + | 1 | 0.4 | -15 | 5.5 | | 2 |2 ,_,_ + | 2 | 3.2 | 18 | 16 | 3 | |__,_,_ + | 3 | 190 | 1000 | 17.5 | -3 | 4 |1 ,8,1 Quantiles are: Dense 0: (-inf,0.4], (0.4,5], (5, 190] Dense 1: (-inf, -9], (-9,15], (15, 1000) Sparse 0: (-inf, 5], (5,16], (16, 100] Sparse 1: (-inf, 2], (2, 5] Sparse 2: (-inf, 100] + SparseM: (-inf, 1], (1,2], (2,1000] """ super(QuantilesOpTest, self).setUp() self._dense_float_tensor_0 = constant_op.constant( @@ -369,18 +381,26 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): self._dense_float_tensor_1 = constant_op.constant( [[-1], [-15], [18], [1000]], dtype=dtypes.float32) # Sparse feature 0 - self._sparse_indices_0 = constant_op.constant([[0, 0], [1, 0], [2, 0], - [3, 0]]) + self._sparse_indices_0 = constant_op.constant( + [[0, 0], [1, 0], [2, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_0 = constant_op.constant([-2, 5.5, 16, 17.5]) self._sparse_shape_0 = constant_op.constant([4, 1]) # Sprase feature 1 - self._sparse_indices_1 = constant_op.constant([[0, 0], [2, 0], [3, 0]]) + self._sparse_indices_1 = constant_op.constant( + [[0, 0], [2, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_1 = constant_op.constant([0.1, 3, -3]) self._sparse_shape_1 = constant_op.constant([4, 1]) # Sprase feature 2 - self._sparse_indices_2 = constant_op.constant([[1, 0], [3, 0]]) + self._sparse_indices_2 = constant_op.constant( + [[1, 0], [3, 0]], dtype=dtypes.int64) self._sparse_values_2 = constant_op.constant([2, 4], dtype=dtypes.float32) self._sparse_shape_2 = constant_op.constant([4, 1]) + # Sprase feature M + self._sparse_indices_m = constant_op.constant( + [[0, 1], [1, 0], [3, 0], [3, 1], [3, 2]], dtype=dtypes.int64) + self._sparse_values_m = constant_op.constant( + [1, 2, 1, 8, 1], dtype=dtypes.float32) + self._sparse_shape_m = constant_op.constant([4, 1]) # Quantiles self._dense_thresholds_0 = [0.4, 5, 190] self._dense_thresholds_1 = [-9, 15, 1000] @@ -388,52 +408,76 @@ class QuantilesOpTest(test_util.TensorFlowTestCase): self._sparse_thresholds_0 = [5, 16, 100] self._sparse_thresholds_1 = [2, 5] self._sparse_thresholds_2 = [100] + self._sparse_thresholds_m = [1, 2, 1000] def testDenseFeaturesOnly(self): with self.test_session(): dense_quantiles, _ = quantile_ops.quantiles( [self._dense_float_tensor_0, self._dense_float_tensor_1], [], - [self._dense_thresholds_0, self._dense_thresholds_1], []) + [self._dense_thresholds_0, self._dense_thresholds_1], [], []) # Dense feature 0 - self.assertAllEqual([0, 0, 1, 2], dense_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [0, 0], [1, 0], [2, 0]], + dense_quantiles[0].eval()) # Dense feature 1 - self.assertAllEqual([1, 0, 2, 2], dense_quantiles[1].eval()) + self.assertAllEqual([[1, 0], [0, 0], [2, 0], [2, 0]], + dense_quantiles[1].eval()) def testSparseFeaturesOnly(self): with self.test_session(): - _, sparse_quantiles = quantile_ops.quantiles( - [], - [self._sparse_values_0, self._sparse_values_1, self._sparse_values_2], - [], [self._sparse_thresholds_0, self._sparse_thresholds_1, - self._sparse_thresholds_2]) - + _, sparse_quantiles = quantile_ops.quantiles([], [ + self._sparse_values_0, self._sparse_values_1, self._sparse_values_2, + self._sparse_values_m + ], [], [ + self._sparse_thresholds_0, self._sparse_thresholds_1, + self._sparse_thresholds_2, self._sparse_thresholds_m + ], [ + self._sparse_indices_0, self._sparse_indices_1, + self._sparse_indices_2, self._sparse_indices_m + ]) + + self.assertAllEqual(4, len(sparse_quantiles)) # Sparse feature 0 - self.assertAllEqual([0, 1, 1, 2], sparse_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [1, 0], [1, 0], [2, 0]], + sparse_quantiles[0].eval()) # Sparse feature 1 - self.assertAllEqual([0, 1, 0], sparse_quantiles[1].eval()) + self.assertAllEqual([[0, 0], [1, 0], [0, 0]], sparse_quantiles[1].eval()) # Sparse feature 2 - self.assertAllEqual([0, 0], sparse_quantiles[2].eval()) + self.assertAllEqual([[0, 0], [0, 0]], sparse_quantiles[2].eval()) + # Multidimensional feature. + self.assertAllEqual([[0, 1], [1, 0], [0, 0], [2, 1], [0, 2]], + sparse_quantiles[3].eval()) def testDenseAndSparseFeatures(self): with self.test_session(): dense_quantiles, sparse_quantiles = quantile_ops.quantiles( - [self._dense_float_tensor_0, self._dense_float_tensor_1], - [self._sparse_values_0, self._sparse_values_1, self._sparse_values_2], - [self._dense_thresholds_0, self._dense_thresholds_1], - [self._sparse_thresholds_0, self._sparse_thresholds_1, - self._sparse_thresholds_2]) + [self._dense_float_tensor_0, self._dense_float_tensor_1], [ + self._sparse_values_0, self._sparse_values_1, + self._sparse_values_2, self._sparse_values_m + ], [self._dense_thresholds_0, self._dense_thresholds_1], [ + self._sparse_thresholds_0, self._sparse_thresholds_1, + self._sparse_thresholds_2, self._sparse_thresholds_m + ], [ + self._sparse_indices_0, self._sparse_indices_1, + self._sparse_indices_2, self._sparse_indices_m + ]) # Dense feature 0 - self.assertAllEqual([0, 0, 1, 2], dense_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [0, 0], [1, 0], [2, 0]], + dense_quantiles[0].eval()) # Dense feature 1 - self.assertAllEqual([1, 0, 2, 2], dense_quantiles[1].eval()) + self.assertAllEqual([[1, 0], [0, 0], [2, 0], [2, 0]], + dense_quantiles[1].eval()) # Sparse feature 0 - self.assertAllEqual([0, 1, 1, 2], sparse_quantiles[0].eval()) + self.assertAllEqual([[0, 0], [1, 0], [1, 0], [2, 0]], + sparse_quantiles[0].eval()) # Sparse feature 1 - self.assertAllEqual([0, 1, 0], sparse_quantiles[1].eval()) + self.assertAllEqual([[0, 0], [1, 0], [0, 0]], sparse_quantiles[1].eval()) # Sparse feature 2 - self.assertAllEqual([0, 0], sparse_quantiles[2].eval()) + self.assertAllEqual([[0, 0], [0, 0]], sparse_quantiles[2].eval()) + # Multidimensional feature. + self.assertAllEqual([[0, 1], [1, 0], [0, 0], [2, 1], [0, 2]], + sparse_quantiles[3].eval()) def testBucketizeWithInputBoundaries(self): with self.test_session(): diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py index edf088b5fa28d3e465d4e3d8ea7cf6745d48a91f..28834ef55bf8e1f32cc8f2380a4be3bf3824d8e1 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py @@ -38,7 +38,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): # (-0.3, 0.19) | 0 | 1 | # (4.0, 0.13) | 1 | 1 | partition_ids = array_ops.constant([0, 0, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([0, 1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[0, 0], [1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([2.4, -0.6, 8.0]) hessians = array_ops.constant([0.4, 0.38, 0.26]) bucket_boundaries = [0.3, 0.52] @@ -109,7 +110,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): """Tests split handler op.""" with self.test_session() as sess: partition_ids = array_ops.constant([0, 0, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([0, 1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[0, 0], [1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([[2.4, 3.0], [-0.6, 0.1], [8.0, 1.0]]) hessians = array_ops.constant([[[0.4, 1], [1, 1]], [[0.38, 1], [1, 1]], [[0.26, 1], [1, 1]]]) @@ -149,7 +151,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): """Tests empty inputs op.""" with self.test_session() as sess: partition_ids = array_ops.constant([], dtype=dtypes.int32) - bucket_ids = array_ops.constant([], dtype=dtypes.int64) + bucket_ids = array_ops.constant([[]], dtype=dtypes.int64) gradients = array_ops.constant([]) hessians = array_ops.constant([]) bucket_boundaries = [0.3, 0.52] @@ -185,7 +187,11 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): # (4.0, 0.13) | 1 | -1 | # (4.0, 0.13) | 1 | 1 | partition_ids = array_ops.constant([0, 0, 0, 1, 1], dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. bucket_ids = array_ops.constant([-1, 0, 1, -1, 1], dtype=dtypes.int64) + dimension_ids = array_ops.constant([0, 0, 0, 0, 0], dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + gradients = array_ops.constant([1.8, 2.4, 0.4, 8.0, 8.0]) hessians = array_ops.constant([0.78, 0.4, 0.24, 0.26, 0.26]) bucket_boundaries = array_ops.constant([0.3, 0.52]) @@ -207,6 +213,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) partitions, gains, splits = (sess.run([partitions, gains, splits])) self.assertAllEqual([0, 1], partitions) + self.assertEqual(2, len(splits)) # Check the split on partition 0. # -(0.2 + 1.2) / (0.12 + 0.2 + 2) expected_left_weight = -0.603448275862069 @@ -232,6 +239,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([expected_right_weight], right_child.value) self.assertEqual(0, split_node.split.feature_column) + # Sparse is one dimensional. + self.assertEqual(0, split_node.split.dimension_id) self.assertAllClose(0.52, split_node.split.threshold) @@ -253,14 +262,149 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): self.assertAllClose([expected_right_weight], right_child.value) self.assertEqual(0, split_node.split.feature_column) + # Sparse is one dimensional. + self.assertEqual(0, split_node.split.dimension_id) self.assertAllClose(0.52, split_node.split.threshold) + def testMakeSparseSplitAllEmptyDimensions(self): + """Tests split handler op when all dimensions have only bias bucket id.""" + with self.test_session() as sess: + # The data looks like the following after dividing by number of steps (2). + # Gradients | Partition | Dimension | bucket ID | + # (0.9, 0.39) | 0 | 0 | -1 | + # (4.0, 0.13) | 1 | 0 | -1 | + partition_ids = array_ops.constant([0, 1], dtype=dtypes.int32) + # We have only 1 dimension in our sparse feature column. + bucket_ids = array_ops.constant([[-1, 0], [-1, 0]], dtype=dtypes.int64) + gradients = array_ops.constant([1.8, 8.0]) + hessians = array_ops.constant([0.78, 0.26]) + bucket_boundaries = array_ops.constant([0.3, 0.52]) + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertEqual(0, len(partitions)) + self.assertEqual(0, len(splits)) + + def testMakeSparseMultidimensionalSplit(self): + """Tests split handler op.""" + with self.test_session() as sess: + # Num of steps is 2. + # The feature column is three dimensional. + # First dimension has bias bucket only, the second has bias bucket and + # two valid buckets, the third has just one bias bucket and one valid + # bucket. + # Gradients | Partition | Dimension | bucket ID | + # (0.9, 0.39) | 0 | 0 | -1 | + # (1.2, 0.2) | 0 | 1 | 0 | + # (0.2, 0.12) | 0 | 1 | 2 | + # (0.1, 0.1) | 0 | 2 | 3 | + # Now second node - nothing interesting there, just one dimension. + # Second node has the same bucket ids for all dimensions. + # (4.0, 0.13) | 1 | 0 | -1 | + # (4.0, 0.13) | 1 | 2 | 3 | + + # Tree node ids. + partition_ids = array_ops.constant([0, 0, 0, 0, 1, 1], dtype=dtypes.int32) + + dimension_ids = array_ops.constant([0, 1, 1, 2, 0, 2], dtype=dtypes.int64) + bucket_ids = array_ops.constant([-1, 0, 2, 3, -1, 3], dtype=dtypes.int64) + bucket_ids = array_ops.stack([bucket_ids, dimension_ids], axis=1) + + gradients = array_ops.constant([1.8, 2.4, 0.4, 0.2, 8.0, 8.0]) + hessians = array_ops.constant([0.78, 0.4, 0.24, 0.2, 0.26, 0.26]) + bucket_boundaries = array_ops.constant([0.3, 0.52, 0.58, 0.6]) + partitions, gains, splits = ( + split_handler_ops.build_sparse_inequality_splits( + num_minibatches=2, + partition_ids=partition_ids, + bucket_ids=bucket_ids, + gradients=gradients, + hessians=hessians, + bucket_boundaries=bucket_boundaries, + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + feature_column_group_id=0, + bias_feature_id=-1, + class_id=-1, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS)) + partitions, gains, splits = (sess.run([partitions, gains, splits])) + self.assertAllEqual([0, 1], partitions) + self.assertEqual(2, len(splits)) + # Check the split on node 0 - it should split on second dimension + # -(0.2 + 1.2) / (0.12 + 0.2 + 2) + expected_left_weight = -0.603448275862069 + # (0.2 + 1.2) ** 2 / (0.12 + 0.2 + 2) + expected_left_gain = 0.8448275862068965 + # 0.5 / (0.07 + 2) + expected_right_weight = 0.24154589371980678 + # 0.5 ** 2 / (0.07 + 2) + expected_right_gain = 0.12077294685990339 + # (0.2 + 1.2 - 0.5) ** 2 / (0.12 + 0.2 + 0.07 + 2) + expected_bias_gain = 0.3389121338912133 + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_right + self.assertAllClose( + expected_left_gain + expected_right_gain - expected_bias_gain, gains[0]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + # Split happened on second dimension. + self.assertEqual(1, split_node.split.dimension_id) + + self.assertAllClose(0.58, split_node.split.threshold) + + # Check the split on partition 1. + expected_left_weight = -1.8779342723004695 + expected_right_weight = 0 + + # Verify candidate for partition 1, there's only one active bucket here + # so zero gain is expected. + split_info.ParseFromString(splits[1]) + left_child = split_info.left_child.vector + right_child = split_info.right_child.vector + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertAllClose(0.0, gains[1]) + + self.assertAllClose([expected_left_weight], left_child.value) + + self.assertAllClose([expected_right_weight], right_child.value) + + self.assertEqual(0, split_node.split.feature_column) + self.assertEqual(2, split_node.split.dimension_id) + + self.assertAllClose(0.6, split_node.split.threshold) + def testMakeMulticlassSparseSplit(self): """Tests split handler op.""" with self.test_session() as sess: partition_ids = array_ops.constant([0, 0, 0, 1, 1], dtype=dtypes.int32) - bucket_ids = array_ops.constant([-1, 0, 1, -1, 1], dtype=dtypes.int64) + bucket_ids = array_ops.constant( + [[-1, 0], [0, 0], [1, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) gradients = array_ops.constant([[1.8, 3.5], [2.4, 1.0], [0.4, 4.0], [8.0, 3.1], [8.0, 0.8]]) @@ -317,7 +461,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): gradients = [1.8, 0.4, 2.8, 8.0, 8.0] hessians = [0.78, 0.24, 0.64, 0.26, 0.26] partition_ids = [0, 0, 0, 1, 1] - feature_ids = array_ops.constant([-1, 1, 2, -1, 1], dtype=dtypes.int64) + feature_ids = array_ops.constant( + [[-1, 0], [1, 0], [2, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=2, @@ -412,7 +557,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): hessians = array_ops.constant( [hessian_0, hessian_1, hessian_2, hessian_3, hessian_4]) partition_ids = [0, 0, 0, 1, 1] - feature_ids = array_ops.constant([-1, 1, 2, -1, 1], dtype=dtypes.int64) + feature_ids = array_ops.constant( + [[-1, 0], [1, 0], [2, 0], [-1, 0], [1, 0]], dtype=dtypes.int64) partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=2, @@ -449,7 +595,7 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase): gradients = [] hessians = [] partition_ids = [] - feature_ids = [] + feature_ids = [[]] partitions, gains, splits = ( split_handler_ops.build_categorical_equality_splits( num_minibatches=0, diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py index 0022d4ad52b0699e6706ad04435f09d0d1cd57c3..978bf530cd99ec6af74a49cb96ff98023d7a15cb 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/stats_accumulator_ops_test.py @@ -38,22 +38,52 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) - op2 = accumulator.add(0, [1], [2], [0.1], [0.2]) + op2 = accumulator.add(0, [1], [[2, 0]], [0.1], [0.2]) with ops.control_dependencies([op1, op2]): - num_updates, partition, feature, grads, hessians = accumulator.flush( + num_updates, partition, bucket_ids, grads, hessians = accumulator.flush( stamp_token=0, next_stamp_token=1) - num_updates, partition, feature, grads, hessians = sess.run( - [num_updates, partition, feature, grads, hessians]) + num_updates, partition, bucket_ids, grads, hessians = sess.run( + [num_updates, partition, bucket_ids, grads, hessians]) - result = _AccumulatorResultToDict(partition, feature, grads, hessians) + result = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians) self.assertEqual(num_updates, 2) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.2, 0.4]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + # Key is partion, bucket, dimension + self.assertAllClose(result[(1, 2, 0)], [0.2, 0.4]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) + + def testMultidimensionalAcculumator(self): + with self.test_session() as sess: + accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.scalar(), + hessian_shape=tensor_shape.scalar()) + with ops.control_dependencies([accumulator._create_op]): + op1 = accumulator.add( + stamp_token=0, + partition_ids=[1, 2, 1], + feature_ids=[[2, 2], [3, 0], [2, 2]], + gradients=[0.1, 0.3, 0.8], + hessians=[0.2, 0.4, -9]) + op2 = accumulator.add(0, [2, 1], [[3, 1], [2, 2]], [0.1, 1], [0.2, -1]) + + with ops.control_dependencies([op1, op2]): + num_updates, partition, bucket_ids, grads, hessians = accumulator.flush( + stamp_token=0, next_stamp_token=1) + num_updates, partition, bucket_ids, grads, hessians = sess.run( + [num_updates, partition, bucket_ids, grads, hessians]) + + result = _AccumulatorResultToDict(partition, bucket_ids, grads, hessians) + self.assertEqual(num_updates, 2) + self.assertEqual(len(result), 3) + # Key is partion, bucket, dimension. + self.assertAllClose(result[(1, 2, 2)], [1.9, -9.8]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) + self.assertAllClose(result[(2, 3, 1)], [0.1, 0.2]) def testDropStaleUpdate(self): with self.test_session() as sess: @@ -65,13 +95,13 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) op2 = accumulator.add( stamp_token=-1, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 0]], gradients=[0.1], hessians=[0.2]) @@ -84,8 +114,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 1) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.1, 0.2]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result[(1, 2, 0)], [0.1, 0.2]) + self.assertAllClose(result[(2, 3, 0)], [0.3, 0.4]) def testSerialize(self): with self.test_session() as sess: @@ -97,7 +127,7 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) @@ -123,8 +153,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): self.assertEqual(num_updates, 1) self.assertEqual(num_updates_2, 1) self.assertEqual(len(result_1), 2) - self.assertAllClose(result_1[(1, 2)], [0.1, 0.2]) - self.assertAllClose(result_1[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result_1[(1, 2, 0)], [0.1, 0.2]) + self.assertAllClose(result_1[(2, 3, 0)], [0.3, 0.4]) self.assertAllEqual(result_1, result_2) self.assertEqual(0, stamp_token) @@ -139,18 +169,19 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 1]], gradients=[0.1, 0.3], hessians=[0.2, 0.4]) with ops.control_dependencies([op1]): - deserialize = (accumulator.deserialize( - stamp_token=2, - num_updates=3, - partition_ids=[3, 4], - feature_ids=[5, 6], - gradients=[0.4, 0.5], - hessians=[0.6, 0.7])) + deserialize = ( + accumulator.deserialize( + stamp_token=2, + num_updates=3, + partition_ids=[3, 4], + feature_ids=[[5, 0], [6, 2]], + gradients=[0.4, 0.5], + hessians=[0.6, 0.7])) with ops.control_dependencies([deserialize]): num_updates, partition, feature, grads, hessians = accumulator.flush( stamp_token=2, next_stamp_token=3) @@ -161,8 +192,8 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): hessians) self.assertEqual(num_updates, 3) self.assertEqual(len(result), 2) - self.assertAllClose(result[(3, 5)], [0.4, 0.6]) - self.assertAllClose(result[(4, 6)], [0.5, 0.7]) + self.assertAllClose(result[(3, 5, 0)], [0.4, 0.6]) + self.assertAllClose(result[(4, 6, 2)], [0.5, 0.7]) def testMakeSummary(self): with self.test_session() as sess: @@ -172,15 +203,15 @@ class StatsAccumulatorScalarTest(test_util.TensorFlowTestCase): hessian_shape=tensor_shape.scalar()) partition, feature, grads, hessians = accumulator._make_summary( partition_ids=[1, 2, 1], - feature_ids=[2, 3, 2], + feature_ids=[[2, 0], [3, 1], [2, 0]], gradients=[0.1, 0.3, 0.1], hessians=[0.2, 0.4, 0.2]) partition, feature, grads, hessians = sess.run( [partition, feature, grads, hessians]) result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)], [0.2, 0.4]) - self.assertAllClose(result[(2, 3)], [0.3, 0.4]) + self.assertAllClose(result[(1, 2, 0)], [0.2, 0.4]) + self.assertAllClose(result[(2, 3, 1)], [0.3, 0.4]) class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): @@ -196,16 +227,54 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], + # Two values for gradients, + gradients=[[0.1, 0.1], [0.2, 0.2]], + # A 2x2 matrix for each hessian. + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) + op2 = accumulator.add( + stamp_token=0, + partition_ids=[1], + feature_ids=[[2, 0]], + gradients=[[0.10, 0.11]], + hessians=[[[0.011, 0.022], [0.033, 0.044]]]) + + with ops.control_dependencies([op1, op2]): + num_updates, partition, feature, grads, hessians = accumulator.flush( + stamp_token=0, next_stamp_token=1) + num_updates, partition, feature, grads, hessians = sess.run( + [num_updates, partition, feature, grads, hessians]) + + result = _AccumulatorResultToDict(partition, feature, grads, hessians) + self.assertEqual(num_updates, 2) + self.assertEqual(len(result), 2) + self.assertAllClose(result[(1, 2, 0)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 0)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) + + def testMultidimensionalAcculumator(self): + with self.test_session() as sess: + accumulator = stats_accumulator_ops.StatsAccumulator( + stamp_token=0, + gradient_shape=tensor_shape.TensorShape([2]), + hessian_shape=tensor_shape.TensorShape([2, 2])) + with ops.control_dependencies([accumulator._create_op]): + op1 = accumulator.add( + stamp_token=0, + partition_ids=[1, 2], + feature_ids=[[2, 4], [3, 1]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. - hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) op2 = accumulator.add( stamp_token=0, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 4]], gradients=[[0.10, 0.11]], hessians=[[[0.011, 0.022], [0.033, 0.044]]]) @@ -218,10 +287,11 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 2) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.20, 0.21]) - self.assertAllClose(result[(1, 2)][1], [[0.021, 0.042], [0.063, 0.084]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 4)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 4)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 1)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 1)][1], [[0.05, 0.06], [0.07, 0.08]]) def testDropStaleUpdate(self): with self.test_session() as sess: @@ -233,16 +303,16 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 5], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. - hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) op2 = accumulator.add( stamp_token=-1, partition_ids=[1], - feature_ids=[2], + feature_ids=[[2, 5]], gradients=[[0.10, 0.11]], hessians=[[[0.011, 0.022], [0.033, 0.044]]]) @@ -255,10 +325,10 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(num_updates, 1) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.1, 0.1]) - self.assertAllClose(result[(1, 2)][1], [[0.01, 0.02], [0.03, 0.04]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 5)][0], [0.1, 0.1]) + self.assertAllClose(result[(1, 2, 5)][1], [[0.01, 0.02], [0.03, 0.04]]) + self.assertAllClose(result[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) def testSerialize(self): with self.test_session() as sess: @@ -270,12 +340,12 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. - hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) with ops.control_dependencies([op1]): (stamp_token, num_updates_1, partition_1, feature_1, grads_1, @@ -300,15 +370,15 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): self.assertEqual(num_updates_1, 1) self.assertEqual(num_updates_2, 1) self.assertEqual(len(result_1), 2) - self.assertAllClose(result_1[(1, 2)][0], [0.1, 0.1]) - self.assertAllClose(result_1[(1, 2)][1], [[0.01, 0.02], [0.03, 0.04]]) - self.assertAllClose(result_1[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result_1[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result_1[(1, 2, 0)][0], [0.1, 0.1]) + self.assertAllClose(result_1[(1, 2, 0)][1], [[0.01, 0.02], [0.03, 0.04]]) + self.assertAllClose(result_1[(2, 3, 0)][0], [0.2, 0.2]) + self.assertAllClose(result_1[(2, 3, 0)][1], [[0.05, 0.06], [0.07, 0.08]]) - self.assertAllEqual(result_1[1, 2][0], result_2[1, 2][0]) - self.assertAllEqual(result_1[1, 2][1], result_2[1, 2][1]) - self.assertAllEqual(result_1[2, 3][0], result_2[2, 3][0]) - self.assertAllEqual(result_1[2, 3][1], result_2[2, 3][1]) + self.assertAllEqual(result_1[1, 2, 0][0], result_2[1, 2, 0][0]) + self.assertAllEqual(result_1[1, 2, 0][1], result_2[1, 2, 0][1]) + self.assertAllEqual(result_1[2, 3, 0][0], result_2[2, 3, 0][0]) + self.assertAllEqual(result_1[2, 3, 0][1], result_2[2, 3, 0][1]) def testDeserialize(self): with self.test_session() as sess: @@ -321,19 +391,19 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): op1 = accumulator.add( stamp_token=0, partition_ids=[1, 2], - feature_ids=[2, 3], + feature_ids=[[2, 0], [3, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2]], # A 2x2 matrix for each hessian. - hessians=[[[0.01, 0.02], [0.03, 0.04]], - [[0.05, 0.06], [0.07, 0.08]]]) + hessians=[[[0.01, 0.02], [0.03, 0.04]], [[0.05, 0.06], [0.07, + 0.08]]]) with ops.control_dependencies([op1]): deserialize = accumulator.deserialize( stamp_token=2, num_updates=3, partition_ids=[3, 4], - feature_ids=[4, 5], + feature_ids=[[4, 0], [5, 0]], # Two values for gradients, gradients=[[0.3, 0.3], [0.5, 0.5]], # A 2x2 matrix for each hessian. @@ -349,10 +419,10 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): hessians) self.assertEqual(num_updates, 3) self.assertEqual(len(result), 2) - self.assertAllClose(result[(3, 4)][0], [0.3, 0.3]) - self.assertAllClose(result[(3, 4)][1], [[0.03, 0.04], [0.05, 0.06]]) - self.assertAllClose(result[(4, 5)][0], [0.5, 0.5]) - self.assertAllClose(result[(4, 5)][1], [[0.07, 0.08], [0.09, 0.10]]) + self.assertAllClose(result[(3, 4, 0)][0], [0.3, 0.3]) + self.assertAllClose(result[(3, 4, 0)][1], [[0.03, 0.04], [0.05, 0.06]]) + self.assertAllClose(result[(4, 5, 0)][0], [0.5, 0.5]) + self.assertAllClose(result[(4, 5, 0)][1], [[0.07, 0.08], [0.09, 0.10]]) def testMakeSummary(self): with self.test_session() as sess: @@ -362,7 +432,7 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): hessian_shape=tensor_shape.TensorShape([2, 2])) partition, feature, grads, hessians = accumulator._make_summary( partition_ids=[1, 2, 1], - feature_ids=[2, 3, 2], + feature_ids=[[2, 0], [3, 2], [2, 0]], # Two values for gradients, gradients=[[0.1, 0.1], [0.2, 0.2], [0.10, 0.11]], # A 2x2 matrix for each hessian. @@ -373,15 +443,16 @@ class StatsAccumulatorTensorTest(test_util.TensorFlowTestCase): result = _AccumulatorResultToDict(partition, feature, grads, hessians) self.assertEqual(len(result), 2) - self.assertAllClose(result[(1, 2)][0], [0.20, 0.21]) - self.assertAllClose(result[(1, 2)][1], [[0.021, 0.042], [0.063, 0.084]]) - self.assertAllClose(result[(2, 3)][0], [0.2, 0.2]) - self.assertAllClose(result[(2, 3)][1], [[0.05, 0.06], [0.07, 0.08]]) + self.assertAllClose(result[(1, 2, 0)][0], [0.20, 0.21]) + self.assertAllClose(result[(1, 2, 0)][1], + [[0.021, 0.042], [0.063, 0.084]]) + self.assertAllClose(result[(2, 3, 2)][0], [0.2, 0.2]) + self.assertAllClose(result[(2, 3, 2)][1], [[0.05, 0.06], [0.07, 0.08]]) def _AccumulatorResultToDict(partition, feature, grads, hessians): """Converts the inputs to a dictionary since the ordering changes.""" - return {(partition[i], feature[i]): (grads[i], hessians[i]) + return {(partition[i], feature[i, 0], feature[i, 1]): (grads[i], hessians[i]) for i in range(len(partition))} diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index f0413fee5a8249d15f2cdae095dc7fa2c76a22b8..c2e65b643df90e88aadb0bb9acaf692da35b1a16 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -181,7 +181,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): tree_weights: 1.0 tree_metadata { num_layers_grown: 1 - is_finalized: true } growing_metadata { num_trees_attempted: 1 @@ -189,7 +188,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): } """ self.assertEqual(new_stamp, 1) - self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_trees, 0) self.assertEqual(stats.num_layers, 1) self.assertEqual(stats.active_tree, 1) self.assertEqual(stats.active_layer, 1) @@ -231,7 +230,6 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): tree_weights: 1.0 tree_metadata { num_layers_grown: 1 - is_finalized: true } growing_metadata { num_trees_attempted: 1 @@ -239,7 +237,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): } """ self.assertEqual(new_stamp, 2) - self.assertEqual(stats.num_trees, 1) + self.assertEqual(stats.num_trees, 0) self.assertEqual(stats.num_layers, 1) self.assertEqual(stats.active_tree, 1) self.assertEqual(stats.active_layer, 1) diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py index d1e6d98efbc588df3db7a8d8186c1135e09bbe57..58f0d36b0f78eeed6abcec1c4fa696f4ccffa615 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py @@ -19,7 +19,6 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader +from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples +from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction # pylint: enable=unused-import -# pylint: disable=wildcard-import -from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import * -# pylint: enable=wildcard-import diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 7e8e15e7d8c89d1adaa472b1da7e8bb3c73ca17e..294e04002adac62fc123a3242a05a1b36f422433 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -45,6 +45,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token, epsilon, num_quantiles, + max_elements=None, name=None, container=None): """Creates a QuantileAccumulator object. @@ -53,6 +54,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token: The initial value for the stamp token. epsilon: Error bound on the quantile computation. num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. name: the name to save the accumulator under. container: An optional `string`. Defaults to `""` """ @@ -67,6 +69,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): self._quantile_accumulator_handle, init_stamp_token, epsilon=epsilon, + max_elements=max_elements, num_quantiles=num_quantiles) is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( self._quantile_accumulator_handle) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 5a917ca42897a263bf9f868393453ba232745e65..b95956dae2a62b28643cd31815c5f5650eca337b 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -208,7 +208,7 @@ def extract_features(features, feature_columns): if tensor.dtype == dtypes.float32: if len(tensor.shape) > 1 and tensor.shape[1] > 1: unstacked = array_ops.unstack(tensor, axis=1) - for i in xrange(len(unstacked)): + for i in range(len(unstacked)): dense_float_names.append(_FEATURE_NAME_TEMPLATE % (key, i)) dense_floats.append(array_ops.reshape(unstacked[i], [-1, 1])) else: @@ -322,9 +322,11 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") (fc_names, dense_floats, sparse_float_indices, sparse_float_values, @@ -494,7 +496,6 @@ class GradientBoostedDecisionTreeModel(object): gate_gradients=0, aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - num_classes = self._learner_config.num_classes class_id = -1 # Handle different multiclass strategies. @@ -503,7 +504,7 @@ class GradientBoostedDecisionTreeModel(object): gradient_shape = tensor_shape.scalar() hessian_shape = tensor_shape.scalar() - if num_classes == 2: + if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. hessians = gradients_impl.gradients( gradients, @@ -522,7 +523,7 @@ class GradientBoostedDecisionTreeModel(object): # Choose the class for which the tree is built (one vs rest). class_id = math_ops.to_int32( - predictions_dict[NUM_TREES_ATTEMPTED] % num_classes) + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) # Use class id tensor to get the column with that index from gradients # and hessians. @@ -532,14 +533,15 @@ class GradientBoostedDecisionTreeModel(object): _get_column_by_index(hessians, class_id)) else: # Other multiclass strategies. - gradient_shape = tensor_shape.TensorShape([num_classes]) + gradient_shape = tensor_shape.TensorShape([self._logits_dimension]) if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: - hessian_shape = tensor_shape.TensorShape(([num_classes, num_classes])) + hessian_shape = tensor_shape.TensorShape( + ([self._logits_dimension, self._logits_dimension])) hessian_list = self._full_hessian(gradients, predictions) else: # Diagonal hessian strategy. - hessian_shape = tensor_shape.TensorShape(([num_classes])) + hessian_shape = tensor_shape.TensorShape(([self._logits_dimension])) hessian_list = self._diagonal_hessian(gradients, predictions) squeezed_gradients = gradients @@ -739,7 +741,7 @@ class GradientBoostedDecisionTreeModel(object): # Accumulate a step after updating stats. batch_size = math_ops.cast(array_ops.shape(labels)[0], dtypes.float32) with ops.control_dependencies(stats_update_ops): - add_step_op = steps_accumulator.add(ensemble_stamp, [0], [0], + add_step_op = steps_accumulator.add(ensemble_stamp, [0], [[0, 0]], [batch_size], [1.0]) # Determine learning rate. @@ -804,10 +806,10 @@ class GradientBoostedDecisionTreeModel(object): # compute the full hessian with a single call to gradients, but instead # must compute it row-by-row. gradients_list = array_ops.unstack( - grads, num=self._learner_config.num_classes, axis=1) + grads, num=self._logits_dimension, axis=1) hessian_rows = [] - for row in range(self._learner_config.num_classes): + for row in range(self._logits_dimension): # If current row is i, K is number of classes,each row returns a tensor of # size batch_size x K representing for each example dx_i dx_1, dx_i dx_2 # etc dx_i dx_K @@ -830,7 +832,7 @@ class GradientBoostedDecisionTreeModel(object): diag_hessian_list = [] gradients_list = array_ops.unstack( - grads, num=self._learner_config.num_classes, axis=1) + grads, num=self._logits_dimension, axis=1) for row, row_grads in enumerate(gradients_list): # If current row is i, K is number of classes,each row returns a tensor of @@ -891,8 +893,10 @@ class GradientBoostedDecisionTreeModel(object): hess_sum = math_ops.reduce_sum(hess, 0) # Accumulate gradients and hessians. - partition_ids = math_ops.range(predictions.get_shape()[1]) - feature_ids = array_ops.zeros_like(partition_ids, dtype=dtypes.int64) + partition_ids = math_ops.range(self._logits_dimension) + feature_ids = array_ops.zeros( + [self._logits_dimension, 2], dtype=dtypes.int64) + add_stats_op = bias_stats_accumulator.add( ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum) return control_flow_ops.group(*[add_stats_op], name="update_bias_stats") diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 16e24d97ddee0751e0b808b89080074c1b4baba7..dba51d4f527792d2a8dedc693f74c07119fd231d 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -912,8 +912,10 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEqual(1, len(output.trees[0].nodes[2].leaf.sparse_vector.index)) self.assertEqual(3, output.trees[0].nodes[2].leaf.sparse_vector.index[0]) - self.assertAlmostEqual( - 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0]) + self.assertAllClose( + 0.893284678459, + output.trees[0].nodes[2].leaf.sparse_vector.value[0], + atol=1e-4, rtol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py index dde16426863b60e9df64da1ee6b36caec273bfd6..ccb8509c0347f9c9b6f1e8f4f620230aac9a6c2d 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses_test.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math - import numpy as np from tensorflow.contrib.boosted_trees.python.utils import losses @@ -60,35 +58,27 @@ class LossesTest(test_util.TensorFlowTestCase): neg_loss = loss_for_negatives.eval() # For positive labels, points <= 0.3 get max loss of e. # For negative labels, these points have minimum loss of 1/e. - for i in range(2): - self.assertAlmostEqual(math.exp(1), pos_loss[i], places=4) - self.assertAlmostEqual(math.exp(-1), neg_loss[i], places=4) + self.assertAllClose(np.exp(np.ones([2, 1])), pos_loss[:2], atol=1e-4) + self.assertAllClose(np.exp(-np.ones([2, 1])), neg_loss[:2], atol=1e-4) # For positive lables, p oints with predictions 0.7 and larger get minimum # loss value of 1/e. For negative labels, these points are wrongly # classified and get loss e. - for i in range(6, 10): - self.assertAlmostEqual(math.exp(-1), pos_loss[i], places=4) - self.assertAlmostEqual(math.exp(1), neg_loss[i], places=4) + self.assertAllClose(np.exp(-np.ones([4, 1])), pos_loss[6:10], atol=1e-4) + self.assertAllClose(np.exp(np.ones([4, 1])), neg_loss[6:10], atol=1e-4) # Points in between 0.5-eps, 0..5+eps get loss exp(-label_m*y), where # y = 1/eps *x -1/(2eps), where x is the probability and label_m is either # 1 or -1 (for label of 0). - for i in range(2, 6): - self.assertAlmostEqual( - math.exp(-1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), - pos_loss[i], - places=4) - self.assertAlmostEqual( - math.exp(1.0 * (predictions_probs[i] * 1.0 / eps - 0.5 / eps)), - neg_loss[i], - places=4) + self.assertAllClose( + np.exp(-(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps)), + pos_loss[2:6], atol=1e-4) + self.assertAllClose( + np.exp(predictions_probs[2:6] * 1.0 / eps - 0.5 / eps), + neg_loss[2:6], atol=1e-4) def test_per_example_squared_loss(self): - def _squared_loss(p, y): - return np.mean(1.0 * (p - y) * (p - y)) - labels = np.array([[0.123], [224.2], [-3], [2], [.3]], dtype=np.float32) weights = array_ops.ones([5, 1], dtypes.float32) predictions = np.array( @@ -99,9 +89,8 @@ class LossesTest(test_util.TensorFlowTestCase): predictions) loss = loss_tensor.eval() - for i in range(5): - self.assertAlmostEqual( - _squared_loss(labels[i], predictions[i]), loss[i], places=4) + self.assertAllClose( + np.square(labels[:5] - predictions[:5]), loss[:5], atol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index aa8f5ed12bc6f779e3c1a923b9225ec283189747..fe8bd072afd43a64fa62a65bd8900b5a98dbe761 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -60,9 +60,7 @@ tf_py_test( size = "small", srcs = ["python/ops/bigquery_reader_ops_test.py"], additional_deps = [ - ":bigquery_reader_ops_op_lib", ":cloud_py", - "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc index b31b882fa19a7eaad304d6d423961234f9affef4..e9b79a066def566096d6c3f3745974423e3371d1 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test.cc @@ -421,7 +421,7 @@ TEST_F(BigQueryTableAccessorTest, MultiplePagesTest) { TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example)); EXPECT_EQ(3, row_id); EXPECT_TRUE(accessor_->Done()); - + Example expected_example; ASSERT_TRUE(protobuf::TextFormat::ParseFromString(kTestExampleProtoWithNulls, &expected_example)); diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index d76ddf8c657b9b5d02bbdc4d6759053396dcd6d2..c74da9cabd6816bc9c7891e32937534cff2d677d 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -80,16 +80,31 @@ class TPUClusterResolver(ClusterResolver): raise ImportError('googleapiclient must be installed before using the ' 'TPU cluster resolver') - # TODO(b/67375680): Remove custom URL once TPU APIs are finalized self._service = discovery.build( - 'tpu', - 'v1', - credentials=self._credentials, - discoveryServiceUrl='https://storage.googleapis.com' - '/tpu-api-definition/v1alpha1.json') + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service + def get_master(self): + """Get the ClusterSpec grpc master path. + + This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the + ClusterSpec returned by the cluster_spec function. This is suitable for use + for the `master` argument in tf.Session() when you are using one TPU. + + Returns: + string, the grpc path of the first instance in the ClusterSpec. + + Raises: + ValueError: If none of the TPUs specified exists. + """ + job_tasks = self.cluster_spec().job_tasks(self._job_name) + if not job_tasks: + raise ValueError('No TPUs exists with the specified names exist.') + + return 'grpc://' + job_tasks[0] + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5bd5cd1a8702840bd3eeb264ff19810fefa1fb62..db7419be06b58e1c5737f69f2c7fd9fee44b9d95 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -26,6 +26,28 @@ from tensorflow.python.training import server_lib mock = test.mock +class MockRequestClass(object): + + def __init__(self, name, tpu_map): + self._name = name + self._tpu_map = tpu_map + + def execute(self): + if self._name in self._tpu_map: + return self._tpu_map[self._name] + else: + raise KeyError('Resource %s was not found' % self._name) + + +class MockNodeClass(object): + + def __init__(self, tpu_map): + self._tpu_map = tpu_map + + def get(self, name): + return MockRequestClass(name, self._tpu_map) + + class TPUClusterResolverTest(test.TestCase): def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): @@ -56,11 +78,15 @@ class TPUClusterResolverTest(test.TestCase): if tpu_map is None: tpu_map = {} - def get_side_effect(name): - return tpu_map[name] + mock_locations = mock.MagicMock() + mock_locations.nodes.return_value = MockNodeClass(tpu_map) + + mock_project = mock.MagicMock() + mock_project.locations.return_value = mock_locations mock_client = mock.MagicMock() - mock_client.projects.locations.nodes.get.side_effect = get_side_effect + mock_client.projects.return_value = mock_project + return mock_client def testSimpleSuccessfulRetrieval(self): @@ -109,3 +135,38 @@ class TPUClusterResolverTest(test.TestCase): tasks { key: 1 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + def testGetMasterMultipleEntries(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470' + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { + 'ipAddress': '10.4.5.6', + 'port': '8470' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=['test-tpu-2', 'test-tpu-1'], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master()) + + def testGetMasterNoEntries(self): + tpu_map = {} + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=[], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + with self.assertRaises(ValueError): + tpu_cluster_resolver.get_master() + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 8744fc492ff67064bff2097c99be5af8a739b60d..481caf6bb076fe823b3cce7a5b574b2e8d08de00 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,6 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) -option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) @@ -35,12 +34,46 @@ option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for th option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) +# GPU, CUDA and cuDNN options +option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) +option(tensorflow_CUDA_VERSION "CUDA version to build against" 9.0) +option(tensorflow_CUDNN_VERSION "cuDNN version to build against" 7) + +if(HAIKU) + option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) +else() + option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" ON) +endif() + + if (NOT WIN32) # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option # for targets that link ${CMAKE_THREAD_LIBS_INIT}. find_package (Threads) + + option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/) + option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/) + if (NOT tensorflow_CUDNN_INCLUDE) + # option's default value is OFF. Fill it with real default values + set(tensorflow_CUDNN_INCLUDE /usr/include) + endif (NOT tensorflow_CUDNN_INCLUDE) + option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) + option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) + option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) + if (NOT tensorflow_CUDA_LIBRARY_PATH) + # option's default value is OFF. Fill it with real default values + set(tensorflow_CUDA_LIBRARY_PATH /usr/local/cuda/lib64) + endif (NOT tensorflow_CUDA_LIBRARY_PATH) endif() +if (WIN32) + set(BOOL_WIN32 ON) +else (WIN32) + set(BOOL_WIN32 OFF) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fPIC") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") +endif (WIN32) + # [CLEANUP] Remove when done # For debugging function(SHOW_VARIABLES) @@ -58,7 +91,12 @@ set (DOWNLOAD_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/downloads" CACHE PATH "Location where external projects will be downloaded.") mark_as_advanced(DOWNLOAD_LOCATION) -set(CMAKE_POSITION_INDEPENDENT_CODE ON) +if (tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) +else() + set(CMAKE_POSITION_INDEPENDENT_CODE OFF) +endif() + add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC) @@ -217,20 +255,35 @@ endif() if(UNIX) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS}) endif() +if(HAIKU) + list(APPEND tensorflow_EXTERNAL_LIBRARIES network) +endif() if (tensorflow_ENABLE_GPU) + if (NOT WIN32) + # Default install paths for cuda libraries in Linux + # In some Linux distros, find_package(CUDA) seems to require CMAKE_LIBRARY_PATH to include cuda-lib paths + list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}") + list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") + endif (NOT WIN32) + + find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED) + + # by default we assume compute cabability 3.5 and 5.2. If you change this change it in + # CUDA_NVCC_FLAGS and cuda_config.h below + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_30,code=\"sm_30,compute_30\";-gencode arch=compute_35,code=\"sm_35,compute_35\";-gencode arch=compute_52,code=\"sm_52,compute_52\") + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero + set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) + include_directories(${CUDA_INCLUDE}) if (WIN32) - find_package(CUDA 8.0 REQUIRED) - - # by default we assume compute cabability 3.5 and 5.2. If you change this change it in - # CUDA_NVCC_FLAGS and cuda_config.h below - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-gencode arch=compute_30,code=\"sm_30,compute_30\";-gencode arch=compute_35,code=\"sm_35,compute_35\";-gencode arch=compute_52,code=\"sm_52,compute_52\") - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};--include-path ${PROJECT_BINARY_DIR}/$\{build_configuration\};--expt-relaxed-constexpr) - set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS};-ftz=true) # Flush denormals to zero - set(CUDA_INCLUDE ${CUDA_TOOLKIT_TARGET_DIR} ${CUDA_TOOLKIT_TARGET_DIR}/extras/CUPTI/include) - include_directories(${CUDA_INCLUDE}) add_definitions(-DGOOGLE_CUDA=1 -DTF_EXTRA_CUDA_CAPABILITIES=3.0,3.5,5.2) + else (WIN32) + # Without these double quotes, cmake in Linux makes it "-DTF_EXTRA_CUDA_CAPABILITIES=3.0, -D3.5, -D5.2" for cc, which incurs build breaks + add_definitions(-DGOOGLE_CUDA=1 -D"TF_EXTRA_CUDA_CAPABILITIES=3.0,3.5,5.2") + endif (WIN32) + if (WIN32) # add cudnn if(NOT CUDNN_HOME) set(CUDNN_HOME ${CUDA_TOOLKIT_TARGET_DIR}) @@ -238,18 +291,51 @@ if (tensorflow_ENABLE_GPU) include_directories(${CUDNN_HOME}) set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib) + else (WIN32) + set(CUDNN_INCLUDE "${tensorflow_CUDNN_INCLUDE}") - # create cuda_config.h - FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h - "#ifndef CUDA_CUDA_CONFIG_H_\n" - "#define CUDA_CUDA_CONFIG_H_\n" - "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" - "#define TF_CUDA_VERSION \"64_80\"\n" - "#define TF_CUDNN_VERSION \"64_6\"\n" - "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" - "#endif // CUDA_CUDA_CONFIG_H_\n" - ) + find_library(nccl_STATIC_LIBRARY NAMES libnccl_static.a PATHS ${tensorflow_PATH_NCCL_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + if (NOT nccl_STATIC_LIBRARY) + message(FATAL_ERROR "NCCL is required for GPU-build") + else (NOT nccl_STATIC_LIBRARY) + message("nccl-static: ${nccl_STATIC_LIBRARY}") + # something like /usr/lib64/libnccl_static.a + endif (NOT nccl_STATIC_LIBRARY) + + find_library(cudnn_STATIC_LIBRARY NAMES libcudnn_static.a PATHS ${tensorflow_PATH_CUDNN_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + if (NOT cudnn_STATIC_LIBRARY) + message(FATAL_ERROR "CUDNN is required for GPU-build") + else (NOT cudnn_STATIC_LIBRARY) + message("cudnn-static: ${cudnn_STATIC_LIBRARY}") + endif (NOT cudnn_STATIC_LIBRARY) + + find_library(culibos_STATIC_LIBRARY NAMES libculibos.a PATHS ${tensorflow_PATH_STATIC_LIB} ${CUDA_TOOLKIT_ROOT_DIR}) + if (NOT culibos_STATIC_LIBRARY) + message(FATAL_ERROR "CULIBOS is required for GPU-build") + else (NOT culibos_STATIC_LIBRARY) + message("culibos-static: ${culibos_STATIC_LIBRARY}") + endif (NOT culibos_STATIC_LIBRARY) + + include_directories(${CUDNN_INCLUDE}) + set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} + ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) + endif (WIN32) + # Remove "." from CUDA version variable. + string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + + # create cuda_config.h + FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h + "#ifndef CUDA_CUDA_CONFIG_H_\n" + "#define CUDA_CUDA_CONFIG_H_\n" + "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" + "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" + "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" + "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" + "#endif // CUDA_CUDA_CONFIG_H_\n" + ) + + if (WIN32) # tf assumes in various places header files to be in cuda/include. On windows the cuda sdk # installs them under cuda/version/include and to avoid that we need to change tf we copy a # few files to cuda/include @@ -261,21 +347,36 @@ if (tensorflow_ENABLE_GPU) ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include ) - include_directories(${tensorflow_source_dir}/third_party/gpus) - # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES - list(APPEND tensorflow_EXTERNAL_LIBRARIES ${CUDA_LIBRARIES}) + else(WIN32) + # Linux has slightly differnt install paths than Windows + FILE(COPY + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_INCLUDE}/cudnn.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h + DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include + ) + endif(WIN32) - # NOTE(mrry): Update these flags when the version of CUDA or cuDNN used - # in the default build is upgraded. + include_directories(${tensorflow_source_dir}/third_party/gpus) + # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES + list(APPEND tensorflow_EXTERNAL_LIBRARIES ${CUDA_LIBRARIES}) + + # NOTE(mrry): Update these flags when the version of CUDA or cuDNN used + # in the default build is upgraded. + if(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll - cudart_dll_name=cudart64_80.dll - cuda_version_number=8.0 + cudart_dll_name=cudart64_${short_CUDA_VER}.dll + cuda_version_number=${tensorflow_CUDA_VERSION} nvcuda_dll_name=nvcuda.dll - cudnn_dll_name=cudnn64_6.dll - cudnn_version_number=6) + cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll + cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) - message(FATAL_ERROR "CMake GPU build is currently only supported on Windows.") + set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value + cuda_version_number=${tensorflow_CUDA_VERSION} + cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value @@ -290,13 +391,7 @@ endif() # Let's get to work! include(tf_core_framework.cmake) -# NOTE: Disabled until issue #3996 is fixed. -# include(tf_stream_executor.cmake) -if (tensorflow_ENABLE_GPU) - if (WIN32) - include(tf_stream_executor.cmake) - endif() -endif() +include(tf_stream_executor.cmake) include(tf_core_cpu.cmake) include(tf_core_ops.cmake) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 4ddfec5960d2b759bacb376202cd8dab6ef2b024..4be733a2809f366a214fa2bb853bccffb10ecaba 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -19,23 +19,6 @@ for instructions on how to install a pre-built TensorFlow package on Windows. ### Current known limitations * It is not possible to load a custom Op library. * GCS file system is not supported. -* The following Ops are not currently implemented: - - Dequantize - - QuantizeAndDequantize - - QuantizedAvgPool - - QuantizedBatchNomWithGlobalNormalization - - QuantizedBiasAdd - - QuantizedConcat - - QuantizedConv2D - - QuantizedMatmul - - QuantizedMaxPoo - - QuantizeDownAndShrinkRange - - QuantizedRelu - - QuantizedRelu6 - - QuantizedReshape - - QuantizeV2 - - RequantizationRange - - Requantize ## Building with CMake diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index dc27eadaca14361ffeffa6eadf6d4d97524de310..cca8444e2ae9952ea7c69a9392580ead715d363b 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -39,8 +39,12 @@ ExternalProject_Add(boringssl # BUILD_IN_SOURCE 1 INSTALL_COMMAND "" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) diff --git a/tensorflow/contrib/cmake/external/farmhash.cmake b/tensorflow/contrib/cmake/external/farmhash.cmake index 96fade8b53273afdc379c7c13017e4917ee534f3..0cd0c1030c73d5218411f281d2b077af217e8275 100644 --- a/tensorflow/contrib/cmake/external/farmhash.cmake +++ b/tensorflow/contrib/cmake/external/farmhash.cmake @@ -15,8 +15,8 @@ include (ExternalProject) set(farmhash_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/farmhash_archive ${CMAKE_CURRENT_BINARY_DIR}/external/farmhash_archive/util) -set(farmhash_URL https://github.com/google/farmhash/archive/34c13ddfab0e35422f4c3979f360635a8c050260.zip) -set(farmhash_HASH SHA256=e3d37a59101f38fd58fb799ed404d630f0eee18bfc2a2433910977cc8fea9c28) +set(farmhash_URL https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz) +set(farmhash_HASH SHA256=6560547c63e4af82b0f202cb710ceabb3f21347a4b996db565a411da5b17aba0) set(farmhash_BUILD ${CMAKE_CURRENT_BINARY_DIR}/farmhash/src/farmhash) set(farmhash_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/farmhash/install) set(farmhash_INCLUDES ${farmhash_BUILD}) diff --git a/tensorflow/contrib/cmake/external/fft2d.cmake b/tensorflow/contrib/cmake/external/fft2d.cmake index a35c24e9e01101f837ba961c06429c981ddc4648..d3af2a46761c0f7f0b5db134af8400fc93f2f095 100644 --- a/tensorflow/contrib/cmake/external/fft2d.cmake +++ b/tensorflow/contrib/cmake/external/fft2d.cmake @@ -15,7 +15,7 @@ include (ExternalProject) -set(fft2d_URL http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz) +set(fft2d_URL https://mirror.bazel.build/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz) set(fft2d_HASH SHA256=52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296) set(fft2d_BUILD ${CMAKE_CURRENT_BINARY_DIR}/fft2d/) set(fft2d_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/fft2d/src) diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index 54a9e96ce58c5501217368b0d12089aa14696b71..a235442dc5c0a07e249653381436eeae81575883 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(gemmlowp_URL http://github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.tar.gz) -set(gemmlowp_HASH SHA256=861cc6d9d902861f54fd77e1ab79286477dcc559b2a283e75b9c22d37b61f6ae) +set(gemmlowp_URL https://github.com/google/gemmlowp/archive/6a2a90822e8546fc2bfa7044de0faf1c1cb4862f.zip) +set(gemmlowp_HASH SHA256=3447948d219f3270383766bbe08942888c0eb4e0ca6663c0e0548502ec5bb77d) set(gemmlowp_BUILD ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 464aad74c6c8623981338695af01b026dcc0e6e3..41ea0b48a4600d7ca2dd2f4a61c14ec0cc5b4734 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,7 +17,7 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 781fd6f6ea03645a520cd5c675da67ab61f87e4b) +set(GRPC_TAG 54e8f37e537794c2d814c1604c1282125f64f093) if(WIN32) set(grpc_STATIC_LIBRARIES @@ -28,10 +28,11 @@ else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/libcares.a) + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a) endif() +add_definitions(-DGRPC_ARES=0) + ExternalProject_Add(grpc PREFIX grpc DEPENDS protobuf zlib @@ -39,9 +40,6 @@ ExternalProject_Add(grpc GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 - # TODO(jhseu): Remove this PATCH_COMMAND once grpc removes the dependency - # on "grpc" from the "grpc++_unsecure" rule. - PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/grpc/CMakeLists.txt ${GRPC_BUILD} BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" diff --git a/tensorflow/contrib/cmake/external/jemalloc.cmake b/tensorflow/contrib/cmake/external/jemalloc.cmake index e4737a1dd825409133cdfd8a54f20dac819c0d5b..198ba13e64e4b6df57c4325a0104b1a6745d173a 100644 --- a/tensorflow/contrib/cmake/external/jemalloc.cmake +++ b/tensorflow/contrib/cmake/external/jemalloc.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(jemalloc_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include) -set(jemalloc_URL https://github.com/jemalloc/jemalloc-cmake/archive/jemalloc-cmake.4.3.1.tar.gz) +set(jemalloc_URL https://mirror.bazel.build/github.com/jemalloc/jemalloc-cmake/archive/jemalloc-cmake.4.3.1.tar.gz) set(jemalloc_HASH SHA256=f9be9a05fe906deb5c1c8ca818071a7d2e27d66fd87f5ba9a7bf3750bcedeaf0) set(jemalloc_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc) diff --git a/tensorflow/contrib/cmake/external/jsoncpp.cmake b/tensorflow/contrib/cmake/external/jsoncpp.cmake index 5127d7e8f79abdda4516eb9f006e243b7438bc65..d2ae4c76e8cd175cdc3ba41fdf4e4009f8237309 100644 --- a/tensorflow/contrib/cmake/external/jsoncpp.cmake +++ b/tensorflow/contrib/cmake/external/jsoncpp.cmake @@ -42,8 +42,12 @@ ExternalProject_Add(jsoncpp BUILD_IN_SOURCE 1 INSTALL_COMMAND "" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) diff --git a/tensorflow/contrib/cmake/external/lmdb.cmake b/tensorflow/contrib/cmake/external/lmdb.cmake index 79971b7cfc3c72e4b6290ccb71d40a20d1180c01..e41384f023ca9fc4cba697917b491af5a9db92bc 100644 --- a/tensorflow/contrib/cmake/external/lmdb.cmake +++ b/tensorflow/contrib/cmake/external/lmdb.cmake @@ -29,10 +29,14 @@ ExternalProject_Add(lmdb INSTALL_DIR ${lmdb_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${lmdb_INSTALL} - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) if(WIN32) diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 2c42377f5078d55e72e37eb5e880624bc09ddef0..05080060479b6240edb8ab9f65160b3dd182feb9 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 394e71f0ebeed6788ae6c84d42c1bedf6e1ee9f7) +set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 2b2bd47d1c95ca886469c525191c27f22d416c29..aad6618f52f909096fd2388e867ef3a965d033cb 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -41,10 +41,14 @@ ExternalProject_Add(png INSTALL_DIR ${png_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL} - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DZLIB_ROOT:STRING=${ZLIB_INSTALL} ) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index 1e300e21df17eeee0abfc2becdab746fbfc62ff6..b53857a47bfbf797af02fe7f69474263119161cd 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -44,8 +44,12 @@ ExternalProject_Add(protobuf ${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS} INSTALL_COMMAND "" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DZLIB_ROOT:STRING=${ZLIB_INSTALL} ) diff --git a/tensorflow/contrib/cmake/external/re2.cmake b/tensorflow/contrib/cmake/external/re2.cmake index cb4ec9c2de3388ef918c75d842dab6e1f4ffee9b..d10f5959f71dd350e6e2bcb81be8882b203fb231 100644 --- a/tensorflow/contrib/cmake/external/re2.cmake +++ b/tensorflow/contrib/cmake/external/re2.cmake @@ -38,7 +38,12 @@ ExternalProject_Add(re2 BUILD_IN_SOURCE 1 DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${re2_INSTALL} - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -) \ No newline at end of file + -DRE2_BUILD_TESTING:BOOL=OFF +) diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index 2d2451521c0f9127e2c76e6270694ac21fe8db93..926c271fd9ea6e2a30251aa408bd49859ae95070 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -40,11 +40,15 @@ ExternalProject_Add(snappy LOG_CONFIGURE ON LOG_BUILD ON CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DSNAPPY_BUILD_TESTS:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) # actually enables snappy in the source code -add_definitions(-DTF_USE_SNAPPY) +add_definitions(-DTF_USE_SNAPPY) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 6fa3a576998acef529942ccfab3a6a544795d712..785039a46983747557607562675349c150e064ad 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(sqlite_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/sqlite) -set(sqlite_URL http://www.sqlite.org/2017/sqlite-amalgamation-3200000.zip) +set(sqlite_URL https://mirror.bazel.build/www.sqlite.org/2017/sqlite-amalgamation-3200000.zip) set(sqlite_HASH SHA256=208780b3616f9de0aeb50822b7a8f5482f6515193859e91ed61637be6ad74fd4) set(sqlite_BUILD ${CMAKE_CURRENT_BINARY_DIR}/sqlite/src/sqlite) set(sqlite_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/sqlite/install) @@ -53,9 +53,13 @@ else() INSTALL_DIR ${sqlite_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_INSTALL_PREFIX:STRING=${sqlite_INSTALL} ) diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index c8af611e1eaefdf135551940a66985a4d50b26ed..f10f84336e8b1c0a2c7de7ea1f8b8af7c21f8b51 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -42,9 +42,13 @@ ExternalProject_Add(zlib BUILD_IN_SOURCE 1 DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) # put zlib includes in the directory where they are expected diff --git a/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt b/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt deleted file mode 100644 index 84722c5ca2a9f9253c7a76dd610dde615a176c07..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cmake/patches/grpc/CMakeLists.txt +++ /dev/null @@ -1,14415 +0,0 @@ -# GRPC global cmake file -# This currently builds C and C++ code. -# This file has been automatically generated from a template file. -# Please look at the templates directory instead. -# This file can be regenerated from the template by running -# tools/buildgen/generate_projects.sh -# -# Copyright 2015 gRPC authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - - -cmake_minimum_required(VERSION 2.8) - -set(PACKAGE_NAME "grpc") -set(PACKAGE_VERSION "1.5.0-dev") -set(PACKAGE_STRING "${PACKAGE_NAME} ${PACKAGE_VERSION}") -set(PACKAGE_TARNAME "${PACKAGE_NAME}-${PACKAGE_VERSION}") -set(PACKAGE_BUGREPORT "https://github.com/grpc/grpc/issues/") -project(${PACKAGE_NAME} C CXX) - -set(gRPC_INSTALL_BINDIR "${CMAKE_INSTALL_PREFIX}/bin" CACHE PATH "Installation directory for executables") -set(gRPC_INSTALL_LIBDIR "${CMAKE_INSTALL_PREFIX}/lib" CACHE PATH "Installation directory for libraries") -set(gRPC_INSTALL_INCLUDEDIR "${CMAKE_INSTALL_PREFIX}/include" CACHE PATH "Installation directory for headers") -set(gRPC_INSTALL_CMAKEDIR "${CMAKE_INSTALL_PREFIX}/lib/cmake/${PACKAGE_NAME}" CACHE PATH "Installation directory for cmake config files") - -# Options -option(gRPC_BUILD_TESTS "Build tests" OFF) - -set(gRPC_INSTALL_default ON) -if (NOT CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) - # Disable gRPC_INSTALL by default if building as a submodule - set(gRPC_INSTALL_default OFF) -endif() -set(gRPC_INSTALL ${gRPC_INSTALL_default} CACHE BOOL - "Generate installation target: gRPC_ZLIB_PROVIDER, gRPC_CARES_PROVIDER, gRPC_SSL_PROVIDER and gRPC_PROTOBUF_PROVIDER must all be \"package\"") - -set(gRPC_ZLIB_PROVIDER "module" CACHE STRING "Provider of zlib library") -set_property(CACHE gRPC_ZLIB_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_CARES_PROVIDER "module" CACHE STRING "Provider of c-ares library") -set_property(CACHE gRPC_CARES_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_SSL_PROVIDER "module" CACHE STRING "Provider of ssl library") -set_property(CACHE gRPC_SSL_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_PROTOBUF_PROVIDER "module" CACHE STRING "Provider of protobuf library") -set_property(CACHE gRPC_PROTOBUF_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_PROTOBUF_PACKAGE_TYPE "" CACHE STRING "Algorithm for searching protobuf package") -set_property(CACHE gRPC_PROTOBUF_PACKAGE_TYPE PROPERTY STRINGS "CONFIG" "MODULE") - -set(gRPC_GFLAGS_PROVIDER "module" CACHE STRING "Provider of gflags library") -set_property(CACHE gRPC_GFLAGS_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_BENCHMARK_PROVIDER "module" CACHE STRING "Provider of benchmark library") -set_property(CACHE gRPC_BENCHMARK_PROVIDER PROPERTY STRINGS "module" "package") - -set(gRPC_USE_PROTO_LITE OFF CACHE BOOL "Use the protobuf-lite library") - -if(UNIX) - if(${CMAKE_SYSTEM_NAME} MATCHES "Linux") - set(_gRPC_PLATFORM_LINUX ON) - elseif(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") - set(_gRPC_PLATFORM_MAC ON) - else() - set(_gRPC_PLATFORM_POSIX ON) - endif() -endif() -if(WIN32) - set(_gRPC_PLATFORM_WINDOWS ON) -endif() - -set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) - -if (MSVC) - include(cmake/msvc_static_runtime.cmake) - add_definitions(-D_WIN32_WINNT=0x600 -D_SCL_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_WARNINGS -D_WINSOCK_DEPRECATED_NO_WARNINGS) - # needed to compile protobuf - add_definitions(/wd4065 /wd4506) - # TODO(jtattermusch): revisit C4267 occurrences throughout the code - add_definitions(/wd4267) -endif() - -if (gRPC_USE_PROTO_LITE) - set(_gRPC_PROTOBUF_LIBRARY_NAME "libprotobuf-lite") - add_definitions("-DGRPC_USE_PROTO_LITE") -else() - set(_gRPC_PROTOBUF_LIBRARY_NAME "libprotobuf") -endif() - -if("${gRPC_ZLIB_PROVIDER}" STREQUAL "module") - if(NOT ZLIB_ROOT_DIR) - set(ZLIB_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/zlib) - endif() - set(ZLIB_INCLUDE_DIR "${ZLIB_ROOT_DIR}") - if(EXISTS "${ZLIB_ROOT_DIR}/CMakeLists.txt") - # TODO(jtattermusch): workaround for https://github.com/madler/zlib/issues/218 - include_directories(${ZLIB_INCLUDE_DIR}) - - add_subdirectory(${ZLIB_ROOT_DIR} third_party/zlib) - if(TARGET zlibstatic) - set(_gRPC_ZLIB_LIBRARIES zlibstatic) - endif() - else() - message(WARNING "gRPC_ZLIB_PROVIDER is \"module\" but ZLIB_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_ZLIB_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_ZLIB_PROVIDER}" STREQUAL "package") - find_package(ZLIB) - if(TARGET ZLIB::ZLIB) - set(_gRPC_ZLIB_LIBRARIES ZLIB::ZLIB) - endif() - set(_gRPC_FIND_ZLIB "if(NOT ZLIB_FOUND)\n find_package(ZLIB)\nendif()") -endif() - -if("${gRPC_CARES_PROVIDER}" STREQUAL "module") - if(NOT CARES_ROOT_DIR) - set(CARES_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src/c-ares) - endif() - string(TOLOWER ${CMAKE_SYSTEM_NAME} CARES_SYSTEM_NAME) - set(CARES_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares/cares") - set(CARES_BUILD_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares") - set(CARES_PLATFORM_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cares/config_${CARES_SYSTEM_NAME}") - if(EXISTS "${CARES_ROOT_DIR}/CMakeLists.txt") - if("${CARES_SYSTEM_NAME}" MATCHES "windows") - add_definitions(-DCARES_STATICLIB=1) - add_definitions(-DWIN32_LEAN_AND_MEAN=1) - else() - add_definitions(-DHAVE_CONFIG_H=1) - add_definitions(-D_GNU_SOURCE=1) - endif() - add_subdirectory(src/c-ares third_party/cares) - if(TARGET cares) - set(_gRPC_CARES_LIBRARIES cares) - endif() - else() - message(WARNING "gRPC_CARES_PROVIDER is \"module\" but CARES_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_CARES_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_CARES_PROVIDER}" STREQUAL "package") - find_package(c-ares CONFIG) - if(TARGET c-ares::cares) - set(_gRPC_CARES_LIBRARIES c-ares::cares) - endif() - set(_gRPC_FIND_CARES "if(NOT c-ares_FOUND)\n find_package(c-ares CONFIG)\nendif()") -endif() - -if("${gRPC_PROTOBUF_PROVIDER}" STREQUAL "module") - # Building the protobuf tests require gmock what is not part of a standard protobuf checkout. - # Disable them unless they are explicitly requested from the cmake command line (when we assume - # gmock is downloaded to the right location inside protobuf). - if(NOT protobuf_BUILD_TESTS) - set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests") - endif() - # Disable building protobuf with zlib. Building protobuf with zlib breaks - # the build if zlib is not installed on the system. - if(NOT protobuf_WITH_ZLIB) - set(protobuf_WITH_ZLIB OFF CACHE BOOL "Build protobuf with zlib.") - endif() - if(NOT PROTOBUF_ROOT_DIR) - set(PROTOBUF_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/protobuf) - endif() - set(PROTOBUF_WELLKNOWN_IMPORT_DIR ${PROTOBUF_ROOT_DIR}/src) - if(EXISTS "${PROTOBUF_ROOT_DIR}/cmake/CMakeLists.txt") - set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "Link static runtime libraries") - add_subdirectory(${PROTOBUF_ROOT_DIR}/cmake third_party/protobuf) - if(TARGET ${_gRPC_PROTOBUF_LIBRARY_NAME}) - set(_gRPC_PROTOBUF_LIBRARIES ${_gRPC_PROTOBUF_LIBRARY_NAME}) - endif() - if(TARGET libprotoc) - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES libprotoc) - endif() - if(TARGET protoc) - set(_gRPC_PROTOBUF_PROTOC protoc) - endif() - else() - message(WARNING "gRPC_PROTOBUF_PROVIDER is \"module\" but PROTOBUF_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_PROTOBUF_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_PROTOBUF_PROVIDER}" STREQUAL "package") - find_package(Protobuf ${gRPC_PROTOBUF_PACKAGE_TYPE}) - if(Protobuf_FOUND OR PROTOBUF_FOUND) - if(TARGET protobuf::${_gRPC_PROTOBUF_LIBRARY_NAME}) - set(_gRPC_PROTOBUF_LIBRARIES protobuf::${_gRPC_PROTOBUF_LIBRARY_NAME}) - else() - set(_gRPC_PROTOBUF_LIBRARIES ${PROTOBUF_LIBRARIES}) - endif() - if(TARGET protobuf::libprotoc) - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES protobuf::libprotoc) - else() - set(_gRPC_PROTOBUF_PROTOC_LIBRARIES ${PROTOBUF_PROTOC_LIBRARIES}) - endif() - if(TARGET protobuf::protoc) - set(_gRPC_PROTOBUF_PROTOC protobuf::protoc) - else() - set(_gRPC_PROTOBUF_PROTOC ${PROTOBUF_PROTOC_EXECUTABLE}) - endif() - set(_gRPC_FIND_PROTOBUF "if(NOT Protobuf_FOUND AND NOT PROTOBUF_FOUND)\n find_package(Protobuf ${gRPC_PROTOBUF_PACKAGE_TYPE})\nendif()") - endif() - if(PROTOBUF_FOUND) - include_directories(${PROTOBUF_INCLUDE_DIRS}) - endif() - set(PROTOBUF_WELLKNOWN_IMPORT_DIR /usr/local/include) -endif() - -if("${gRPC_SSL_PROVIDER}" STREQUAL "module") - if(NOT BORINGSSL_ROOT_DIR) - set(BORINGSSL_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/boringssl) - endif() - if(EXISTS "${BORINGSSL_ROOT_DIR}/CMakeLists.txt") - set(OPENSSL_NO_ASM ON) # make boringssl buildable with Visual Studio - add_subdirectory(${BORINGSSL_ROOT_DIR} third_party/boringssl) - if(TARGET ssl) - set(_gRPC_SSL_LIBRARIES ssl) - endif() - else() - message(WARNING "gRPC_SSL_PROVIDER is \"module\" but BORINGSSL_ROOT_DIR is wrong") - endif() - if(gRPC_INSTALL) - message(WARNING "gRPC_INSTALL will be forced to FALSE because gRPC_SSL_PROVIDER is \"module\"") - set(gRPC_INSTALL FALSE) - endif() -elseif("${gRPC_SSL_PROVIDER}" STREQUAL "package") - find_package(OpenSSL) - if(TARGET OpenSSL::SSL) - set(_gRPC_SSL_LIBRARIES OpenSSL::SSL) - endif() - set(_gRPC_FIND_SSL "if(NOT OpenSSL_FOUND)\n find_package(OpenSSL)\nendif()") -endif() - -if("${gRPC_GFLAGS_PROVIDER}" STREQUAL "module") - if(NOT GFLAGS_ROOT_DIR) - set(GFLAGS_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/gflags) - endif() - if(EXISTS "${GFLAGS_ROOT_DIR}/CMakeLists.txt") - add_subdirectory(${GFLAGS_ROOT_DIR} third_party/gflags) - if(TARGET gflags_static) - set(_gRPC_GFLAGS_LIBRARIES gflags_static) - endif() - else() - message(WARNING "gRPC_GFLAGS_PROVIDER is \"module\" but GFLAGS_ROOT_DIR is wrong") - endif() -elseif("${gRPC_GFLAGS_PROVIDER}" STREQUAL "package") - find_package(gflags) - if(TARGET gflags::gflags) - set(_gRPC_GFLAGS_LIBRARIES gflags::gflags) - endif() - set(_gRPC_FIND_GFLAGS "if(NOT gflags_FOUND)\n find_package(gflags)\nendif()") -endif() - -if("${gRPC_BENCHMARK_PROVIDER}" STREQUAL "module") - if(NOT BENCHMARK_ROOT_DIR) - set(BENCHMARK_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/benchmark) - endif() - if(EXISTS "${BENCHMARK_ROOT_DIR}/CMakeLists.txt") - add_subdirectory(${BENCHMARK_ROOT_DIR} third_party/benchmark) - if(TARGET benchmark) - set(_gRPC_BENCHMARK_LIBRARIES benchmark) - endif() - else() - message(WARNING "gRPC_BENCHMARK_PROVIDER is \"module\" but BENCHMARK_ROOT_DIR is wrong") - endif() -elseif("${gRPC_BENCHMARK_PROVIDER}" STREQUAL "package") - find_package(benchmark) - if(TARGET benchmark::benchmark) - set(_gRPC_BENCHMARK_LIBRARIES benchmark::benchmark) - endif() - set(_gRPC_FIND_BENCHMARK "if(NOT benchmark_FOUND)\n find_package(benchmark)\nendif()") -endif() - -if(NOT MSVC) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -std=c99") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") -endif() - -if(_gRPC_PLATFORM_MAC) - set(_gRPC_ALLTARGETS_LIBRARIES ${CMAKE_DL_LIBS} m pthread) -elseif(UNIX) - set(_gRPC_ALLTARGETS_LIBRARIES ${CMAKE_DL_LIBS} rt m pthread) -endif() - -if(WIN32 AND MSVC) - set(_gRPC_BASELIB_LIBRARIES wsock32 ws2_32) -endif() - -# Create directory for generated .proto files -set(_gRPC_PROTO_GENS_DIR ${CMAKE_BINARY_DIR}/gens) -file(MAKE_DIRECTORY ${_gRPC_PROTO_GENS_DIR}) - -# protobuf_generate_grpc_cpp -# -------------------------- -# -# Add custom commands to process ``.proto`` files to C++ using protoc and -# GRPC plugin:: -# -# protobuf_generate_grpc_cpp [...] -# -# ``ARGN`` -# ``.proto`` files -# -function(protobuf_generate_grpc_cpp) - if(NOT ARGN) - message(SEND_ERROR "Error: PROTOBUF_GENERATE_GRPC_CPP() called without any proto files") - return() - endif() - - set(_protobuf_include_path -I . -I ${PROTOBUF_WELLKNOWN_IMPORT_DIR}) - foreach(FIL ${ARGN}) - get_filename_component(ABS_FIL ${FIL} ABSOLUTE) - get_filename_component(FIL_WE ${FIL} NAME_WE) - file(RELATIVE_PATH REL_FIL ${CMAKE_CURRENT_SOURCE_DIR} ${ABS_FIL}) - get_filename_component(REL_DIR ${REL_FIL} DIRECTORY) - set(RELFIL_WE "${REL_DIR}/${FIL_WE}") - - add_custom_command( - OUTPUT "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.cc" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.h" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}_mock.grpc.pb.h" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.cc" - "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.h" - COMMAND $ - ARGS --grpc_out=generate_mock_code=true:${_gRPC_PROTO_GENS_DIR} - --cpp_out=${_gRPC_PROTO_GENS_DIR} - --plugin=protoc-gen-grpc=$ - ${_protobuf_include_path} - ${REL_FIL} - DEPENDS ${ABS_FIL} ${_gRPC_PROTOBUF_PROTOC} grpc_cpp_plugin - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} - COMMENT "Running gRPC C++ protocol buffer compiler on ${FIL}" - VERBATIM) - - set_source_files_properties("${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.cc" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.grpc.pb.h" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}_mock.grpc.pb.h" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.cc" "${_gRPC_PROTO_GENS_DIR}/${RELFIL_WE}.pb.h" PROPERTIES GENERATED TRUE) - endforeach() -endfunction() - -add_custom_target(plugins - DEPENDS - grpc_cpp_plugin - grpc_csharp_plugin - grpc_node_plugin - grpc_objective_c_plugin - grpc_php_plugin - grpc_python_plugin - grpc_ruby_plugin -) - -add_custom_target(tools_c - DEPENDS - check_epollexclusive - gen_hpack_tables - gen_legal_metadata_characters - gen_percent_encoding_tables - grpc_create_jwt - grpc_print_google_default_creds_token - grpc_verify_jwt -) - -add_custom_target(tools_cxx - DEPENDS -) - -add_custom_target(tools - DEPENDS tools_c tools_cxx) - -if (gRPC_BUILD_TESTS) -add_custom_target(buildtests_c) -add_dependencies(buildtests_c alarm_test) -add_dependencies(buildtests_c algorithm_test) -add_dependencies(buildtests_c alloc_test) -add_dependencies(buildtests_c alpn_test) -add_dependencies(buildtests_c arena_test) -add_dependencies(buildtests_c bad_server_response_test) -add_dependencies(buildtests_c bdp_estimator_test) -add_dependencies(buildtests_c bin_decoder_test) -add_dependencies(buildtests_c bin_encoder_test) -add_dependencies(buildtests_c census_context_test) -add_dependencies(buildtests_c census_intrusive_hash_map_test) -add_dependencies(buildtests_c census_resource_test) -add_dependencies(buildtests_c census_trace_context_test) -add_dependencies(buildtests_c channel_create_test) -add_dependencies(buildtests_c chttp2_hpack_encoder_test) -add_dependencies(buildtests_c chttp2_stream_map_test) -add_dependencies(buildtests_c chttp2_varint_test) -add_dependencies(buildtests_c combiner_test) -add_dependencies(buildtests_c compression_test) -add_dependencies(buildtests_c concurrent_connectivity_test) -add_dependencies(buildtests_c connection_refused_test) -add_dependencies(buildtests_c dns_resolver_connectivity_test) -add_dependencies(buildtests_c dns_resolver_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c dualstack_socket_test) -endif() -add_dependencies(buildtests_c endpoint_pair_test) -add_dependencies(buildtests_c error_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c ev_epollsig_linux_test) -endif() -add_dependencies(buildtests_c fake_resolver_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fd_conservation_posix_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fd_posix_test) -endif() -add_dependencies(buildtests_c fling_client) -add_dependencies(buildtests_c fling_server) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fling_stream_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c fling_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c goaway_server_test) -endif() -add_dependencies(buildtests_c gpr_avl_test) -add_dependencies(buildtests_c gpr_backoff_test) -add_dependencies(buildtests_c gpr_cmdline_test) -add_dependencies(buildtests_c gpr_cpu_test) -add_dependencies(buildtests_c gpr_env_test) -add_dependencies(buildtests_c gpr_histogram_test) -add_dependencies(buildtests_c gpr_host_port_test) -add_dependencies(buildtests_c gpr_log_test) -add_dependencies(buildtests_c gpr_mpscq_test) -add_dependencies(buildtests_c gpr_spinlock_test) -add_dependencies(buildtests_c gpr_stack_lockfree_test) -add_dependencies(buildtests_c gpr_string_test) -add_dependencies(buildtests_c gpr_sync_test) -add_dependencies(buildtests_c gpr_thd_test) -add_dependencies(buildtests_c gpr_time_test) -add_dependencies(buildtests_c gpr_tls_test) -add_dependencies(buildtests_c gpr_useful_test) -add_dependencies(buildtests_c grpc_auth_context_test) -add_dependencies(buildtests_c grpc_b64_test) -add_dependencies(buildtests_c grpc_byte_buffer_reader_test) -add_dependencies(buildtests_c grpc_channel_args_test) -add_dependencies(buildtests_c grpc_channel_stack_test) -add_dependencies(buildtests_c grpc_completion_queue_test) -add_dependencies(buildtests_c grpc_completion_queue_threading_test) -add_dependencies(buildtests_c grpc_credentials_test) -add_dependencies(buildtests_c grpc_fetch_oauth2) -add_dependencies(buildtests_c grpc_invalid_channel_args_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c grpc_json_token_test) -endif() -add_dependencies(buildtests_c grpc_jwt_verifier_test) -add_dependencies(buildtests_c grpc_security_connector_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c handshake_client) -endif() -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c handshake_server) -endif() -add_dependencies(buildtests_c hpack_parser_test) -add_dependencies(buildtests_c hpack_table_test) -add_dependencies(buildtests_c http_parser_test) -add_dependencies(buildtests_c httpcli_format_request_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c httpcli_test) -endif() -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c httpscli_test) -endif() -add_dependencies(buildtests_c init_test) -add_dependencies(buildtests_c invalid_call_argument_test) -add_dependencies(buildtests_c json_rewrite) -add_dependencies(buildtests_c json_rewrite_test) -add_dependencies(buildtests_c json_stream_error_test) -add_dependencies(buildtests_c json_test) -add_dependencies(buildtests_c lame_client_test) -add_dependencies(buildtests_c lb_policies_test) -add_dependencies(buildtests_c load_file_test) -add_dependencies(buildtests_c memory_profile_client) -add_dependencies(buildtests_c memory_profile_server) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c memory_profile_test) -endif() -add_dependencies(buildtests_c message_compress_test) -add_dependencies(buildtests_c minimal_stack_is_minimal_test) -add_dependencies(buildtests_c mlog_test) -add_dependencies(buildtests_c multiple_server_queues_test) -add_dependencies(buildtests_c murmur_hash_test) -add_dependencies(buildtests_c no_server_test) -add_dependencies(buildtests_c num_external_connectivity_watchers_test) -add_dependencies(buildtests_c parse_address_test) -add_dependencies(buildtests_c percent_encoding_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c pollset_set_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c resolve_address_posix_test) -endif() -add_dependencies(buildtests_c resolve_address_test) -add_dependencies(buildtests_c resource_quota_test) -add_dependencies(buildtests_c secure_channel_create_test) -add_dependencies(buildtests_c secure_endpoint_test) -add_dependencies(buildtests_c sequential_connectivity_test) -add_dependencies(buildtests_c server_chttp2_test) -add_dependencies(buildtests_c server_test) -add_dependencies(buildtests_c slice_buffer_test) -add_dependencies(buildtests_c slice_hash_table_test) -add_dependencies(buildtests_c slice_string_helpers_test) -add_dependencies(buildtests_c slice_test) -add_dependencies(buildtests_c sockaddr_resolver_test) -add_dependencies(buildtests_c sockaddr_utils_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c socket_utils_test) -endif() -add_dependencies(buildtests_c status_conversion_test) -add_dependencies(buildtests_c stream_compression_test) -add_dependencies(buildtests_c stream_owned_slice_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_client_posix_test) -endif() -add_dependencies(buildtests_c tcp_client_uv_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_posix_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c tcp_server_posix_test) -endif() -add_dependencies(buildtests_c tcp_server_uv_test) -add_dependencies(buildtests_c time_averaged_stats_test) -add_dependencies(buildtests_c timeout_encoding_test) -add_dependencies(buildtests_c timer_heap_test) -add_dependencies(buildtests_c timer_list_test) -add_dependencies(buildtests_c transport_connectivity_state_test) -add_dependencies(buildtests_c transport_metadata_test) -add_dependencies(buildtests_c transport_pid_controller_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c transport_security_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c udp_server_test) -endif() -add_dependencies(buildtests_c uri_parser_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c wakeup_fd_cv_test) -endif() -add_dependencies(buildtests_c public_headers_must_be_c89) -add_dependencies(buildtests_c badreq_bad_client_test) -add_dependencies(buildtests_c connection_prefix_bad_client_test) -add_dependencies(buildtests_c head_of_line_blocking_bad_client_test) -add_dependencies(buildtests_c headers_bad_client_test) -add_dependencies(buildtests_c initial_settings_frame_bad_client_test) -add_dependencies(buildtests_c large_metadata_bad_client_test) -add_dependencies(buildtests_c server_registered_method_bad_client_test) -add_dependencies(buildtests_c simple_request_bad_client_test) -add_dependencies(buildtests_c unknown_frame_bad_client_test) -add_dependencies(buildtests_c window_overflow_bad_client_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c bad_ssl_cert_server) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c bad_ssl_cert_test) -endif() -add_dependencies(buildtests_c h2_census_test) -add_dependencies(buildtests_c h2_compress_test) -add_dependencies(buildtests_c h2_fakesec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_fd_test) -endif() -add_dependencies(buildtests_c h2_full_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c h2_full+pipe_test) -endif() -add_dependencies(buildtests_c h2_full+trace_test) -add_dependencies(buildtests_c h2_full+workarounds_test) -add_dependencies(buildtests_c h2_http_proxy_test) -add_dependencies(buildtests_c h2_load_reporting_test) -add_dependencies(buildtests_c h2_oauth2_test) -add_dependencies(buildtests_c h2_proxy_test) -add_dependencies(buildtests_c h2_sockpair_test) -add_dependencies(buildtests_c h2_sockpair+trace_test) -add_dependencies(buildtests_c h2_sockpair_1byte_test) -add_dependencies(buildtests_c h2_ssl_test) -add_dependencies(buildtests_c h2_ssl_cert_test) -add_dependencies(buildtests_c h2_ssl_proxy_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_uds_test) -endif() -add_dependencies(buildtests_c inproc_test) -add_dependencies(buildtests_c h2_census_nosec_test) -add_dependencies(buildtests_c h2_compress_nosec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_fd_nosec_test) -endif() -add_dependencies(buildtests_c h2_full_nosec_test) -if(_gRPC_PLATFORM_LINUX) -add_dependencies(buildtests_c h2_full+pipe_nosec_test) -endif() -add_dependencies(buildtests_c h2_full+trace_nosec_test) -add_dependencies(buildtests_c h2_full+workarounds_nosec_test) -add_dependencies(buildtests_c h2_http_proxy_nosec_test) -add_dependencies(buildtests_c h2_load_reporting_nosec_test) -add_dependencies(buildtests_c h2_proxy_nosec_test) -add_dependencies(buildtests_c h2_sockpair_nosec_test) -add_dependencies(buildtests_c h2_sockpair+trace_nosec_test) -add_dependencies(buildtests_c h2_sockpair_1byte_nosec_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_c h2_uds_nosec_test) -endif() -add_dependencies(buildtests_c inproc_nosec_test) -add_dependencies(buildtests_c api_fuzzer_one_entry) -add_dependencies(buildtests_c client_fuzzer_one_entry) -add_dependencies(buildtests_c hpack_parser_fuzzer_test_one_entry) -add_dependencies(buildtests_c http_request_fuzzer_test_one_entry) -add_dependencies(buildtests_c http_response_fuzzer_test_one_entry) -add_dependencies(buildtests_c json_fuzzer_test_one_entry) -add_dependencies(buildtests_c nanopb_fuzzer_response_test_one_entry) -add_dependencies(buildtests_c nanopb_fuzzer_serverlist_test_one_entry) -add_dependencies(buildtests_c percent_decode_fuzzer_one_entry) -add_dependencies(buildtests_c percent_encode_fuzzer_one_entry) -add_dependencies(buildtests_c server_fuzzer_one_entry) -add_dependencies(buildtests_c ssl_server_fuzzer_one_entry) -add_dependencies(buildtests_c uri_fuzzer_test_one_entry) - -add_custom_target(buildtests_cxx) -add_dependencies(buildtests_cxx alarm_cpp_test) -add_dependencies(buildtests_cxx async_end2end_test) -add_dependencies(buildtests_cxx auth_property_iterator_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_arena) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_call_create) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_chttp2_hpack) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_chttp2_transport) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_closure) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_cq) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_cq_multiple_threads) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_error) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_streaming_ping_pong) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_streaming_pump) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_trickle) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_fullstack_unary_ping_pong) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_metadata) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx bm_pollset) -endif() -add_dependencies(buildtests_cxx channel_arguments_test) -add_dependencies(buildtests_cxx channel_filter_test) -add_dependencies(buildtests_cxx cli_call_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx client_crash_test) -endif() -add_dependencies(buildtests_cxx client_crash_test_server) -add_dependencies(buildtests_cxx client_lb_end2end_test) -add_dependencies(buildtests_cxx codegen_test_full) -add_dependencies(buildtests_cxx codegen_test_minimal) -add_dependencies(buildtests_cxx credentials_test) -add_dependencies(buildtests_cxx cxx_byte_buffer_test) -add_dependencies(buildtests_cxx cxx_slice_test) -add_dependencies(buildtests_cxx cxx_string_ref_test) -add_dependencies(buildtests_cxx cxx_time_test) -add_dependencies(buildtests_cxx end2end_test) -add_dependencies(buildtests_cxx error_details_test) -add_dependencies(buildtests_cxx filter_end2end_test) -add_dependencies(buildtests_cxx generic_end2end_test) -add_dependencies(buildtests_cxx golden_file_test) -add_dependencies(buildtests_cxx grpc_cli) -add_dependencies(buildtests_cxx grpc_tool_test) -add_dependencies(buildtests_cxx grpclb_api_test) -add_dependencies(buildtests_cxx grpclb_end2end_test) -add_dependencies(buildtests_cxx grpclb_test) -add_dependencies(buildtests_cxx health_service_end2end_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx http2_client) -endif() -add_dependencies(buildtests_cxx hybrid_end2end_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_client) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_server) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx interop_test) -endif() -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx json_run_localhost) -endif() -add_dependencies(buildtests_cxx memory_test) -add_dependencies(buildtests_cxx metrics_client) -add_dependencies(buildtests_cxx mock_test) -add_dependencies(buildtests_cxx noop-benchmark) -add_dependencies(buildtests_cxx proto_server_reflection_test) -add_dependencies(buildtests_cxx proto_utils_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx qps_interarrival_test) -endif() -add_dependencies(buildtests_cxx qps_json_driver) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx qps_openloop_test) -endif() -add_dependencies(buildtests_cxx qps_worker) -add_dependencies(buildtests_cxx reconnect_interop_client) -add_dependencies(buildtests_cxx reconnect_interop_server) -add_dependencies(buildtests_cxx secure_auth_context_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx secure_sync_unary_ping_pong_test) -endif() -add_dependencies(buildtests_cxx server_builder_plugin_test) -add_dependencies(buildtests_cxx server_builder_test) -add_dependencies(buildtests_cxx server_context_test_spouse_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx server_crash_test) -endif() -add_dependencies(buildtests_cxx server_crash_test_client) -add_dependencies(buildtests_cxx server_request_call_test) -add_dependencies(buildtests_cxx shutdown_test) -add_dependencies(buildtests_cxx status_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx streaming_throughput_test) -endif() -add_dependencies(buildtests_cxx stress_test) -add_dependencies(buildtests_cxx thread_manager_test) -add_dependencies(buildtests_cxx thread_stress_test) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) -add_dependencies(buildtests_cxx writes_per_rpc_test) -endif() - -add_custom_target(buildtests - DEPENDS buildtests_c buildtests_cxx) -endif (gRPC_BUILD_TESTS) - - -add_library(gpr - src/core/lib/profiling/basic_timers.c - src/core/lib/profiling/stap_timers.c - src/core/lib/support/alloc.c - src/core/lib/support/arena.c - src/core/lib/support/atm.c - src/core/lib/support/avl.c - src/core/lib/support/backoff.c - src/core/lib/support/cmdline.c - src/core/lib/support/cpu_iphone.c - src/core/lib/support/cpu_linux.c - src/core/lib/support/cpu_posix.c - src/core/lib/support/cpu_windows.c - src/core/lib/support/env_linux.c - src/core/lib/support/env_posix.c - src/core/lib/support/env_windows.c - src/core/lib/support/histogram.c - src/core/lib/support/host_port.c - src/core/lib/support/log.c - src/core/lib/support/log_android.c - src/core/lib/support/log_linux.c - src/core/lib/support/log_posix.c - src/core/lib/support/log_windows.c - src/core/lib/support/mpscq.c - src/core/lib/support/murmur_hash.c - src/core/lib/support/stack_lockfree.c - src/core/lib/support/string.c - src/core/lib/support/string_posix.c - src/core/lib/support/string_util_windows.c - src/core/lib/support/string_windows.c - src/core/lib/support/subprocess_posix.c - src/core/lib/support/subprocess_windows.c - src/core/lib/support/sync.c - src/core/lib/support/sync_posix.c - src/core/lib/support/sync_windows.c - src/core/lib/support/thd.c - src/core/lib/support/thd_posix.c - src/core/lib/support/thd_windows.c - src/core/lib/support/time.c - src/core/lib/support/time_posix.c - src/core/lib/support/time_precise.c - src/core/lib/support/time_windows.c - src/core/lib/support/tls_pthread.c - src/core/lib/support/tmpfile_msys.c - src/core/lib/support/tmpfile_posix.c - src/core/lib/support/tmpfile_windows.c - src/core/lib/support/wrap_memcpy.c -) - -if(WIN32 AND MSVC) - set_target_properties(gpr PROPERTIES COMPILE_PDB_NAME "gpr" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/gpr.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(gpr - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr - ${_gRPC_ALLTARGETS_LIBRARIES} -) - -foreach(_hdr - include/grpc/support/alloc.h - include/grpc/support/atm.h - include/grpc/support/atm_gcc_atomic.h - include/grpc/support/atm_gcc_sync.h - include/grpc/support/atm_windows.h - include/grpc/support/avl.h - include/grpc/support/cmdline.h - include/grpc/support/cpu.h - include/grpc/support/histogram.h - include/grpc/support/host_port.h - include/grpc/support/log.h - include/grpc/support/log_windows.h - include/grpc/support/port_platform.h - include/grpc/support/string_util.h - include/grpc/support/subprocess.h - include/grpc/support/sync.h - include/grpc/support/sync_generic.h - include/grpc/support/sync_posix.h - include/grpc/support/sync_windows.h - include/grpc/support/thd.h - include/grpc/support/time.h - include/grpc/support/tls.h - include/grpc/support/tls_gcc.h - include/grpc/support/tls_msvc.h - include/grpc/support/tls_pthread.h - include/grpc/support/useful.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS gpr EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(gpr_test_util - test/core/util/test_config.c -) - -if(WIN32 AND MSVC) - set_target_properties(gpr_test_util PROPERTIES COMPILE_PDB_NAME "gpr_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/gpr_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(gpr_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_test_util - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc - src/core/lib/surface/init.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/server/secure/server_secure_chttp2.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/lib/http/httpcli_security_connector.c - src/core/lib/security/context/security_context.c - src/core/lib/security/credentials/composite/composite_credentials.c - src/core/lib/security/credentials/credentials.c - src/core/lib/security/credentials/credentials_metadata.c - src/core/lib/security/credentials/fake/fake_credentials.c - src/core/lib/security/credentials/google_default/credentials_generic.c - src/core/lib/security/credentials/google_default/google_default_credentials.c - src/core/lib/security/credentials/iam/iam_credentials.c - src/core/lib/security/credentials/jwt/json_token.c - src/core/lib/security/credentials/jwt/jwt_credentials.c - src/core/lib/security/credentials/jwt/jwt_verifier.c - src/core/lib/security/credentials/oauth2/oauth2_credentials.c - src/core/lib/security/credentials/plugin/plugin_credentials.c - src/core/lib/security/credentials/ssl/ssl_credentials.c - src/core/lib/security/transport/client_auth_filter.c - src/core/lib/security/transport/lb_targets_info.c - src/core/lib/security/transport/secure_endpoint.c - src/core/lib/security/transport/security_connector.c - src/core/lib/security/transport/security_handshaker.c - src/core/lib/security/transport/server_auth_filter.c - src/core/lib/security/transport/tsi_error.c - src/core/lib/security/util/json_util.c - src/core/lib/surface/init_secure.c - src/core/tsi/fake_transport_security.c - src/core/tsi/gts_transport_security.c - src/core/tsi/ssl_transport_security.c - src/core/tsi/transport_security.c - src/core/tsi/transport_security_adapter.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/transport/chttp2/client/secure/secure_channel_create.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/inproc/inproc_plugin.c - src/core/ext/transport/inproc/inproc_transport.c - src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel_secure.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.c - src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.c - src/core/ext/filters/client_channel/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.c - src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.c - src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.c - src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c - src/core/ext/filters/max_age/max_age_filter.c - src/core/ext/filters/message_size/message_size_filter.c - src/core/ext/filters/workarounds/workaround_cronet_compression_filter.c - src/core/ext/filters/workarounds/workaround_utils.c - src/core/plugin_registry/grpc_plugin_registry.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc PROPERTIES COMPILE_PDB_NAME "grpc" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/grpc_security.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc_cronet - src/core/lib/surface/init.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/cronet/client/secure/cronet_channel_create.c - src/core/ext/transport/cronet/transport/cronet_api_dummy.c - src/core/ext/transport/cronet/transport/cronet_transport.c - src/core/ext/transport/chttp2/client/secure/secure_channel_create.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/lib/http/httpcli_security_connector.c - src/core/lib/security/context/security_context.c - src/core/lib/security/credentials/composite/composite_credentials.c - src/core/lib/security/credentials/credentials.c - src/core/lib/security/credentials/credentials_metadata.c - src/core/lib/security/credentials/fake/fake_credentials.c - src/core/lib/security/credentials/google_default/credentials_generic.c - src/core/lib/security/credentials/google_default/google_default_credentials.c - src/core/lib/security/credentials/iam/iam_credentials.c - src/core/lib/security/credentials/jwt/json_token.c - src/core/lib/security/credentials/jwt/jwt_credentials.c - src/core/lib/security/credentials/jwt/jwt_verifier.c - src/core/lib/security/credentials/oauth2/oauth2_credentials.c - src/core/lib/security/credentials/plugin/plugin_credentials.c - src/core/lib/security/credentials/ssl/ssl_credentials.c - src/core/lib/security/transport/client_auth_filter.c - src/core/lib/security/transport/lb_targets_info.c - src/core/lib/security/transport/secure_endpoint.c - src/core/lib/security/transport/security_connector.c - src/core/lib/security/transport/security_handshaker.c - src/core/lib/security/transport/server_auth_filter.c - src/core/lib/security/transport/tsi_error.c - src/core/lib/security/util/json_util.c - src/core/lib/surface/init_secure.c - src/core/tsi/fake_transport_security.c - src/core/tsi/gts_transport_security.c - src/core/tsi/ssl_transport_security.c - src/core/tsi/transport_security.c - src/core/tsi/transport_security_adapter.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/plugin_registry/grpc_cronet_plugin_registry.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_cronet PROPERTIES COMPILE_PDB_NAME "grpc_cronet" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_cronet.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_cronet - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_cronet - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/grpc_cronet.h - include/grpc/grpc_security.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_cronet EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc_test_util - test/core/end2end/data/client_certs.c - test/core/end2end/data/server1_cert.c - test/core/end2end/data/server1_key.c - test/core/end2end/data/test_root_cert.c - test/core/security/oauth2_utils.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - test/core/end2end/cq_verifier.c - test/core/end2end/fixtures/http_proxy_fixture.c - test/core/end2end/fixtures/proxy.c - test/core/iomgr/endpoint_tests.c - test/core/util/debugger_macros.c - test/core/util/grpc_profiler.c - test/core/util/memory_counters.c - test/core/util/mock_endpoint.c - test/core/util/parse_hexstring.c - test/core/util/passthru_endpoint.c - test/core/util/port.c - test/core/util/port_server_client.c - test/core/util/slice_splitter.c - test/core/util/trickle_endpoint.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_test_util PROPERTIES COMPILE_PDB_NAME "grpc_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_test_util - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr - grpc -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc_test_util_unsecure - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - test/core/end2end/cq_verifier.c - test/core/end2end/fixtures/http_proxy_fixture.c - test/core/end2end/fixtures/proxy.c - test/core/iomgr/endpoint_tests.c - test/core/util/debugger_macros.c - test/core/util/grpc_profiler.c - test/core/util/memory_counters.c - test/core/util/mock_endpoint.c - test/core/util/parse_hexstring.c - test/core/util/passthru_endpoint.c - test/core/util/port.c - test/core/util/port_server_client.c - test/core/util/slice_splitter.c - test/core/util/trickle_endpoint.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_test_util_unsecure PROPERTIES COMPILE_PDB_NAME "grpc_test_util_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_test_util_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_test_util_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_test_util_unsecure - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - gpr_test_util - grpc_unsecure - grpc -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_unsecure - src/core/lib/surface/init.c - src/core/lib/surface/init_unsecure.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/inproc/inproc_plugin.c - src/core/ext/transport/inproc/inproc_transport.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/dns_resolver_ares.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_ev_driver_posix.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper.c - src/core/ext/filters/client_channel/resolver/dns/c_ares/grpc_ares_wrapper_fallback.c - src/core/ext/filters/client_channel/resolver/dns/native/dns_resolver.c - src/core/ext/filters/client_channel/resolver/sockaddr/sockaddr_resolver.c - src/core/ext/filters/client_channel/resolver/fake/fake_resolver.c - src/core/ext/filters/load_reporting/load_reporting.c - src/core/ext/filters/load_reporting/load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/client_load_reporting_filter.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_channel.c - src/core/ext/filters/client_channel/lb_policy/grpclb/grpclb_client_stats.c - src/core/ext/filters/client_channel/lb_policy/grpclb/load_balancer_api.c - src/core/ext/filters/client_channel/lb_policy/grpclb/proto/grpc/lb/v1/load_balancer.pb.c - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.c - src/core/ext/filters/client_channel/lb_policy/round_robin/round_robin.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c - src/core/ext/filters/max_age/max_age_filter.c - src/core/ext/filters/message_size/message_size_filter.c - src/core/ext/filters/workarounds/workaround_cronet_compression_filter.c - src/core/ext/filters/workarounds/workaround_utils.c - src/core/plugin_registry/grpc_unsecure_plugin_registry.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_unsecure PROPERTIES COMPILE_PDB_NAME "grpc_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_unsecure - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_ZLIB_LIBRARIES} - ${_gRPC_CARES_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr -) - -foreach(_hdr - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_unsecure EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(reconnect_server - test/core/util/reconnect_server.c -) - -if(WIN32 AND MSVC) - set_target_properties(reconnect_server PROPERTIES COMPILE_PDB_NAME "reconnect_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/reconnect_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(reconnect_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(reconnect_server - ${_gRPC_ALLTARGETS_LIBRARIES} - test_tcp_server - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(test_tcp_server - test/core/util/test_tcp_server.c -) - -if(WIN32 AND MSVC) - set_target_properties(test_tcp_server PROPERTIES COMPILE_PDB_NAME "test_tcp_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/test_tcp_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(test_tcp_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(test_tcp_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++ - src/cpp/client/insecure_credentials.cc - src/cpp/client/secure_credentials.cc - src/cpp/common/auth_property_iterator.cc - src/cpp/common/secure_auth_context.cc - src/cpp/common/secure_channel_arguments.cc - src/cpp/common/secure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/server/secure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++ PROPERTIES COMPILE_PDB_NAME "grpc++" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++ - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++ - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - include/grpc++/security/auth_context.h - include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc++/impl/codegen/proto_utils.h - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++ EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc++_cronet - src/cpp/client/cronet_credentials.cc - src/cpp/client/insecure_credentials.cc - src/cpp/common/insecure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc - src/core/ext/transport/chttp2/client/insecure/channel_create.c - src/core/ext/transport/chttp2/client/insecure/channel_create_posix.c - src/core/ext/transport/chttp2/client/chttp2_connector.c - src/core/ext/transport/chttp2/transport/bin_decoder.c - src/core/ext/transport/chttp2/transport/bin_encoder.c - src/core/ext/transport/chttp2/transport/chttp2_plugin.c - src/core/ext/transport/chttp2/transport/chttp2_transport.c - src/core/ext/transport/chttp2/transport/frame_data.c - src/core/ext/transport/chttp2/transport/frame_goaway.c - src/core/ext/transport/chttp2/transport/frame_ping.c - src/core/ext/transport/chttp2/transport/frame_rst_stream.c - src/core/ext/transport/chttp2/transport/frame_settings.c - src/core/ext/transport/chttp2/transport/frame_window_update.c - src/core/ext/transport/chttp2/transport/hpack_encoder.c - src/core/ext/transport/chttp2/transport/hpack_parser.c - src/core/ext/transport/chttp2/transport/hpack_table.c - src/core/ext/transport/chttp2/transport/http2_settings.c - src/core/ext/transport/chttp2/transport/huffsyms.c - src/core/ext/transport/chttp2/transport/incoming_metadata.c - src/core/ext/transport/chttp2/transport/parsing.c - src/core/ext/transport/chttp2/transport/stream_lists.c - src/core/ext/transport/chttp2/transport/stream_map.c - src/core/ext/transport/chttp2/transport/varint.c - src/core/ext/transport/chttp2/transport/writing.c - src/core/lib/channel/channel_args.c - src/core/lib/channel/channel_stack.c - src/core/lib/channel/channel_stack_builder.c - src/core/lib/channel/connected_channel.c - src/core/lib/channel/handshaker.c - src/core/lib/channel/handshaker_factory.c - src/core/lib/channel/handshaker_registry.c - src/core/lib/compression/compression.c - src/core/lib/compression/message_compress.c - src/core/lib/compression/stream_compression.c - src/core/lib/http/format_request.c - src/core/lib/http/httpcli.c - src/core/lib/http/parser.c - src/core/lib/iomgr/closure.c - src/core/lib/iomgr/combiner.c - src/core/lib/iomgr/endpoint.c - src/core/lib/iomgr/endpoint_pair_posix.c - src/core/lib/iomgr/endpoint_pair_uv.c - src/core/lib/iomgr/endpoint_pair_windows.c - src/core/lib/iomgr/error.c - src/core/lib/iomgr/ev_epoll1_linux.c - src/core/lib/iomgr/ev_epoll_limited_pollers_linux.c - src/core/lib/iomgr/ev_epoll_thread_pool_linux.c - src/core/lib/iomgr/ev_epollex_linux.c - src/core/lib/iomgr/ev_epollsig_linux.c - src/core/lib/iomgr/ev_poll_posix.c - src/core/lib/iomgr/ev_posix.c - src/core/lib/iomgr/ev_windows.c - src/core/lib/iomgr/exec_ctx.c - src/core/lib/iomgr/executor.c - src/core/lib/iomgr/iocp_windows.c - src/core/lib/iomgr/iomgr.c - src/core/lib/iomgr/iomgr_posix.c - src/core/lib/iomgr/iomgr_uv.c - src/core/lib/iomgr/iomgr_windows.c - src/core/lib/iomgr/is_epollexclusive_available.c - src/core/lib/iomgr/load_file.c - src/core/lib/iomgr/lockfree_event.c - src/core/lib/iomgr/network_status_tracker.c - src/core/lib/iomgr/polling_entity.c - src/core/lib/iomgr/pollset_set_uv.c - src/core/lib/iomgr/pollset_set_windows.c - src/core/lib/iomgr/pollset_uv.c - src/core/lib/iomgr/pollset_windows.c - src/core/lib/iomgr/resolve_address_posix.c - src/core/lib/iomgr/resolve_address_uv.c - src/core/lib/iomgr/resolve_address_windows.c - src/core/lib/iomgr/resource_quota.c - src/core/lib/iomgr/sockaddr_utils.c - src/core/lib/iomgr/socket_factory_posix.c - src/core/lib/iomgr/socket_mutator.c - src/core/lib/iomgr/socket_utils_common_posix.c - src/core/lib/iomgr/socket_utils_linux.c - src/core/lib/iomgr/socket_utils_posix.c - src/core/lib/iomgr/socket_utils_uv.c - src/core/lib/iomgr/socket_utils_windows.c - src/core/lib/iomgr/socket_windows.c - src/core/lib/iomgr/tcp_client_posix.c - src/core/lib/iomgr/tcp_client_uv.c - src/core/lib/iomgr/tcp_client_windows.c - src/core/lib/iomgr/tcp_posix.c - src/core/lib/iomgr/tcp_server_posix.c - src/core/lib/iomgr/tcp_server_utils_posix_common.c - src/core/lib/iomgr/tcp_server_utils_posix_ifaddrs.c - src/core/lib/iomgr/tcp_server_utils_posix_noifaddrs.c - src/core/lib/iomgr/tcp_server_uv.c - src/core/lib/iomgr/tcp_server_windows.c - src/core/lib/iomgr/tcp_uv.c - src/core/lib/iomgr/tcp_windows.c - src/core/lib/iomgr/time_averaged_stats.c - src/core/lib/iomgr/timer_generic.c - src/core/lib/iomgr/timer_heap.c - src/core/lib/iomgr/timer_manager.c - src/core/lib/iomgr/timer_uv.c - src/core/lib/iomgr/udp_server.c - src/core/lib/iomgr/unix_sockets_posix.c - src/core/lib/iomgr/unix_sockets_posix_noop.c - src/core/lib/iomgr/wakeup_fd_cv.c - src/core/lib/iomgr/wakeup_fd_eventfd.c - src/core/lib/iomgr/wakeup_fd_nospecial.c - src/core/lib/iomgr/wakeup_fd_pipe.c - src/core/lib/iomgr/wakeup_fd_posix.c - src/core/lib/json/json.c - src/core/lib/json/json_reader.c - src/core/lib/json/json_string.c - src/core/lib/json/json_writer.c - src/core/lib/slice/b64.c - src/core/lib/slice/percent_encoding.c - src/core/lib/slice/slice.c - src/core/lib/slice/slice_buffer.c - src/core/lib/slice/slice_hash_table.c - src/core/lib/slice/slice_intern.c - src/core/lib/slice/slice_string_helpers.c - src/core/lib/surface/alarm.c - src/core/lib/surface/api_trace.c - src/core/lib/surface/byte_buffer.c - src/core/lib/surface/byte_buffer_reader.c - src/core/lib/surface/call.c - src/core/lib/surface/call_details.c - src/core/lib/surface/call_log_batch.c - src/core/lib/surface/channel.c - src/core/lib/surface/channel_init.c - src/core/lib/surface/channel_ping.c - src/core/lib/surface/channel_stack_type.c - src/core/lib/surface/completion_queue.c - src/core/lib/surface/completion_queue_factory.c - src/core/lib/surface/event_string.c - src/core/lib/surface/lame_client.cc - src/core/lib/surface/metadata_array.c - src/core/lib/surface/server.c - src/core/lib/surface/validate_metadata.c - src/core/lib/surface/version.c - src/core/lib/transport/bdp_estimator.c - src/core/lib/transport/byte_stream.c - src/core/lib/transport/connectivity_state.c - src/core/lib/transport/error_utils.c - src/core/lib/transport/metadata.c - src/core/lib/transport/metadata_batch.c - src/core/lib/transport/pid_controller.c - src/core/lib/transport/service_config.c - src/core/lib/transport/static_metadata.c - src/core/lib/transport/status_conversion.c - src/core/lib/transport/timeout_encoding.c - src/core/lib/transport/transport.c - src/core/lib/transport/transport_op_string.c - src/core/lib/debug/trace.c - src/core/ext/transport/chttp2/alpn/alpn.c - src/core/ext/filters/http/client/http_client_filter.c - src/core/ext/filters/http/http_filters_plugin.c - src/core/ext/filters/http/message_compress/message_compress_filter.c - src/core/ext/filters/http/server/http_server_filter.c - src/core/ext/filters/client_channel/channel_connectivity.c - src/core/ext/filters/client_channel/client_channel.c - src/core/ext/filters/client_channel/client_channel_factory.c - src/core/ext/filters/client_channel/client_channel_plugin.c - src/core/ext/filters/client_channel/connector.c - src/core/ext/filters/client_channel/http_connect_handshaker.c - src/core/ext/filters/client_channel/http_proxy.c - src/core/ext/filters/client_channel/lb_policy.c - src/core/ext/filters/client_channel/lb_policy_factory.c - src/core/ext/filters/client_channel/lb_policy_registry.c - src/core/ext/filters/client_channel/parse_address.c - src/core/ext/filters/client_channel/proxy_mapper.c - src/core/ext/filters/client_channel/proxy_mapper_registry.c - src/core/ext/filters/client_channel/resolver.c - src/core/ext/filters/client_channel/resolver_factory.c - src/core/ext/filters/client_channel/resolver_registry.c - src/core/ext/filters/client_channel/retry_throttle.c - src/core/ext/filters/client_channel/subchannel.c - src/core/ext/filters/client_channel/subchannel_index.c - src/core/ext/filters/client_channel/uri_parser.c - src/core/ext/filters/deadline/deadline_filter.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2.c - src/core/ext/transport/chttp2/server/insecure/server_chttp2_posix.c - src/core/ext/transport/chttp2/server/chttp2_server.c - src/core/ext/census/base_resources.c - src/core/ext/census/context.c - src/core/ext/census/gen/census.pb.c - src/core/ext/census/gen/trace_context.pb.c - src/core/ext/census/grpc_context.c - src/core/ext/census/grpc_filter.c - src/core/ext/census/grpc_plugin.c - src/core/ext/census/initialize.c - src/core/ext/census/intrusive_hash_map.c - src/core/ext/census/mlog.c - src/core/ext/census/operation.c - src/core/ext/census/placeholders.c - src/core/ext/census/resource.c - src/core/ext/census/trace_context.c - src/core/ext/census/tracing.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_cronet PROPERTIES COMPILE_PDB_NAME "grpc++_cronet" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_cronet.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_cronet - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_cronet - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc_cronet - grpc -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - include/grpc++/security/auth_context.h - include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc/byte_buffer.h - include/grpc/byte_buffer_reader.h - include/grpc/compression.h - include/grpc/grpc.h - include/grpc/grpc_posix.h - include/grpc/grpc_security_constants.h - include/grpc/load_reporting.h - include/grpc/slice.h - include/grpc/slice_buffer.h - include/grpc/status.h - include/grpc/support/workaround_list.h - include/grpc/census.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_cronet EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_library(grpc++_error_details - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/status/status.grpc.pb.h - src/cpp/util/error_details.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_error_details PROPERTIES COMPILE_PDB_NAME "grpc++_error_details" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_error_details.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/status/status.proto -) - -target_include_directories(grpc++_error_details - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_error_details - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ -) - -foreach(_hdr - include/grpc++/support/error_details.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_error_details EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc++_proto_reflection_desc_db - test/cpp/util/proto_reflection_descriptor_database.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_proto_reflection_desc_db PROPERTIES COMPILE_PDB_NAME "grpc++_proto_reflection_desc_db" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_proto_reflection_desc_db.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc++_proto_reflection_desc_db - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_proto_reflection_desc_db - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++_reflection - src/cpp/ext/proto_server_reflection.cc - src/cpp/ext/proto_server_reflection_plugin.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_reflection PROPERTIES COMPILE_PDB_NAME "grpc++_reflection" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_reflection.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc++_reflection - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_reflection - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/ext/proto_server_reflection_plugin.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_reflection EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc++_test_config - test/cpp/util/test_config_cc.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_test_config PROPERTIES COMPILE_PDB_NAME "grpc++_test_config" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_test_config.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_test_config - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_test_config - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc++_test_util - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/health/v1/health.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_mock.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/duplicate/echo_duplicate.grpc.pb.h - test/cpp/end2end/test_service_impl.cc - test/cpp/util/byte_buffer_proto_helper.cc - test/cpp/util/create_test_channel.cc - test/cpp/util/string_ref_helper.cc - test/cpp/util/subprocess.cc - test/cpp/util/test_credentials_provider.cc - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_test_util PROPERTIES COMPILE_PDB_NAME "grpc++_test_util" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_test_util.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/health/v1/health.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/duplicate/echo_duplicate.proto -) - -target_include_directories(grpc++_test_util - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_test_util - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc_test_util - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h - include/grpc++/impl/codegen/proto_utils.h - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc++_unsecure - src/cpp/client/insecure_credentials.cc - src/cpp/common/insecure_create_auth_context.cc - src/cpp/server/insecure_server_credentials.cc - src/cpp/client/channel_cc.cc - src/cpp/client/client_context.cc - src/cpp/client/create_channel.cc - src/cpp/client/create_channel_internal.cc - src/cpp/client/create_channel_posix.cc - src/cpp/client/credentials_cc.cc - src/cpp/client/generic_stub.cc - src/cpp/common/channel_arguments.cc - src/cpp/common/channel_filter.cc - src/cpp/common/completion_queue_cc.cc - src/cpp/common/core_codegen.cc - src/cpp/common/resource_quota_cc.cc - src/cpp/common/rpc_method.cc - src/cpp/common/version_cc.cc - src/cpp/server/async_generic_service.cc - src/cpp/server/channel_argument_option.cc - src/cpp/server/create_default_thread_pool.cc - src/cpp/server/dynamic_thread_pool.cc - src/cpp/server/health/default_health_check_service.cc - src/cpp/server/health/health.pb.c - src/cpp/server/health/health_check_service.cc - src/cpp/server/health/health_check_service_server_builder_option.cc - src/cpp/server/server_builder.cc - src/cpp/server/server_cc.cc - src/cpp/server/server_context.cc - src/cpp/server/server_credentials.cc - src/cpp/server/server_posix.cc - src/cpp/thread_manager/thread_manager.cc - src/cpp/util/byte_buffer_cc.cc - src/cpp/util/slice_cc.cc - src/cpp/util/status.cc - src/cpp/util/string_ref.cc - src/cpp/util/time_cc.cc - third_party/nanopb/pb_common.c - third_party/nanopb/pb_decode.c - third_party/nanopb/pb_encode.c - src/cpp/codegen/codegen_init.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc++_unsecure PROPERTIES COMPILE_PDB_NAME "grpc++_unsecure" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc++_unsecure.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc++_unsecure - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc++_unsecure - ${_gRPC_BASELIB_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc_unsecure -) - -foreach(_hdr - include/grpc++/alarm.h - include/grpc++/channel.h - include/grpc++/client_context.h - include/grpc++/completion_queue.h - include/grpc++/create_channel.h - include/grpc++/create_channel_posix.h - include/grpc++/ext/health_check_service_server_builder_option.h - include/grpc++/generic/async_generic_service.h - include/grpc++/generic/generic_stub.h - include/grpc++/grpc++.h - include/grpc++/health_check_service_interface.h - include/grpc++/impl/call.h - include/grpc++/impl/channel_argument_option.h - include/grpc++/impl/client_unary_call.h - include/grpc++/impl/codegen/core_codegen.h - include/grpc++/impl/grpc_library.h - include/grpc++/impl/method_handler_impl.h - include/grpc++/impl/rpc_method.h - include/grpc++/impl/rpc_service_method.h - include/grpc++/impl/serialization_traits.h - include/grpc++/impl/server_builder_option.h - include/grpc++/impl/server_builder_plugin.h - include/grpc++/impl/server_initializer.h - include/grpc++/impl/service_type.h - include/grpc++/resource_quota.h - include/grpc++/security/auth_context.h - include/grpc++/security/auth_metadata_processor.h - include/grpc++/security/credentials.h - include/grpc++/security/server_credentials.h - include/grpc++/server.h - include/grpc++/server_builder.h - include/grpc++/server_context.h - include/grpc++/server_posix.h - include/grpc++/support/async_stream.h - include/grpc++/support/async_unary_call.h - include/grpc++/support/byte_buffer.h - include/grpc++/support/channel_arguments.h - include/grpc++/support/config.h - include/grpc++/support/slice.h - include/grpc++/support/status.h - include/grpc++/support/status_code_enum.h - include/grpc++/support/string_ref.h - include/grpc++/support/stub_options.h - include/grpc++/support/sync_stream.h - include/grpc++/support/time.h - include/grpc++/impl/codegen/async_stream.h - include/grpc++/impl/codegen/async_unary_call.h - include/grpc++/impl/codegen/call.h - include/grpc++/impl/codegen/call_hook.h - include/grpc++/impl/codegen/channel_interface.h - include/grpc++/impl/codegen/client_context.h - include/grpc++/impl/codegen/client_unary_call.h - include/grpc++/impl/codegen/completion_queue.h - include/grpc++/impl/codegen/completion_queue_tag.h - include/grpc++/impl/codegen/config.h - include/grpc++/impl/codegen/core_codegen_interface.h - include/grpc++/impl/codegen/create_auth_context.h - include/grpc++/impl/codegen/grpc_library.h - include/grpc++/impl/codegen/metadata_map.h - include/grpc++/impl/codegen/method_handler_impl.h - include/grpc++/impl/codegen/rpc_method.h - include/grpc++/impl/codegen/rpc_service_method.h - include/grpc++/impl/codegen/security/auth_context.h - include/grpc++/impl/codegen/serialization_traits.h - include/grpc++/impl/codegen/server_context.h - include/grpc++/impl/codegen/server_interface.h - include/grpc++/impl/codegen/service_type.h - include/grpc++/impl/codegen/slice.h - include/grpc++/impl/codegen/status.h - include/grpc++/impl/codegen/status_code_enum.h - include/grpc++/impl/codegen/string_ref.h - include/grpc++/impl/codegen/stub_options.h - include/grpc++/impl/codegen/sync_stream.h - include/grpc++/impl/codegen/time.h - include/grpc/impl/codegen/byte_buffer_reader.h - include/grpc/impl/codegen/compression_types.h - include/grpc/impl/codegen/connectivity_state.h - include/grpc/impl/codegen/exec_ctx_fwd.h - include/grpc/impl/codegen/grpc_types.h - include/grpc/impl/codegen/propagation_bits.h - include/grpc/impl/codegen/slice.h - include/grpc/impl/codegen/status.h - include/grpc/impl/codegen/atm.h - include/grpc/impl/codegen/atm_gcc_atomic.h - include/grpc/impl/codegen/atm_gcc_sync.h - include/grpc/impl/codegen/atm_windows.h - include/grpc/impl/codegen/gpr_slice.h - include/grpc/impl/codegen/gpr_types.h - include/grpc/impl/codegen/port_platform.h - include/grpc/impl/codegen/sync.h - include/grpc/impl/codegen/sync_generic.h - include/grpc/impl/codegen/sync_posix.h - include/grpc/impl/codegen/sync_windows.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc++_unsecure EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(grpc_benchmark - test/cpp/microbenchmarks/helpers.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_benchmark PROPERTIES COMPILE_PDB_NAME "grpc_benchmark" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_benchmark.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_benchmark - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_benchmark - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - benchmark - grpc++ - grpc_test_util - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(grpc_cli_libs - test/cpp/util/cli_call.cc - test/cpp/util/cli_credentials.cc - test/cpp/util/grpc_tool.cc - test/cpp/util/proto_file_parser.cc - test/cpp/util/service_describer.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/reflection/v1alpha/reflection.grpc.pb.h -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_cli_libs PROPERTIES COMPILE_PDB_NAME "grpc_cli_libs" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_cli_libs.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/reflection/v1alpha/reflection.proto -) - -target_include_directories(grpc_cli_libs - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cli_libs - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_proto_reflection_desc_db - grpc++ - grpc -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_plugin_support - src/compiler/cpp_generator.cc - src/compiler/csharp_generator.cc - src/compiler/node_generator.cc - src/compiler/objective_c_generator.cc - src/compiler/php_generator.cc - src/compiler/python_generator.cc - src/compiler/ruby_generator.cc -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_plugin_support PROPERTIES COMPILE_PDB_NAME "grpc_plugin_support" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_plugin_support.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_plugin_support - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_plugin_support - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - -foreach(_hdr - include/grpc++/impl/codegen/config_protobuf.h -) - string(REPLACE "include/" "" _path ${_hdr}) - get_filename_component(_path ${_path} PATH) - install(FILES ${_hdr} - DESTINATION "${gRPC_INSTALL_INCLUDEDIR}/${_path}" - ) -endforeach() - - -if (gRPC_INSTALL) - install(TARGETS grpc_plugin_support EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(http2_client_main - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/http2_client.cc -) - -if(WIN32 AND MSVC) - set_target_properties(http2_client_main PROPERTIES COMPILE_PDB_NAME "http2_client_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/http2_client_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(http2_client_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(http2_client_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_client_helper - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - test/cpp/interop/client_helper.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_client_helper PROPERTIES COMPILE_PDB_NAME "interop_client_helper" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_client_helper.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) - -target_include_directories(interop_client_helper - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client_helper - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_client_main - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/client.cc - test/cpp/interop/interop_client.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_client_main PROPERTIES COMPILE_PDB_NAME "interop_client_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_client_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(interop_client_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_client_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_helper - test/cpp/interop/server_helper.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_helper PROPERTIES COMPILE_PDB_NAME "interop_server_helper" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_helper.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(interop_server_helper - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_helper - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_lib - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/interop_server.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_lib PROPERTIES COMPILE_PDB_NAME "interop_server_lib" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_lib.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(interop_server_lib - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_lib - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(interop_server_main - test/cpp/interop/interop_server_bootstrap.cc -) - -if(WIN32 AND MSVC) - set_target_properties(interop_server_main PROPERTIES COMPILE_PDB_NAME "interop_server_main" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/interop_server_main.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(interop_server_main - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server_main - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_lib -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(qps - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - test/cpp/qps/benchmark_config.cc - test/cpp/qps/client_async.cc - test/cpp/qps/client_sync.cc - test/cpp/qps/driver.cc - test/cpp/qps/parse_json.cc - test/cpp/qps/qps_worker.cc - test/cpp/qps/report.cc - test/cpp/qps/server_async.cc - test/cpp/qps/server_sync.cc - test/cpp/qps/usage_timer.cc -) - -if(WIN32 AND MSVC) - set_target_properties(qps PROPERTIES COMPILE_PDB_NAME "qps" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/qps.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) - -target_include_directories(qps - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++_test_util - grpc++ - grpc -) - - -endif (gRPC_BUILD_TESTS) - -add_library(grpc_csharp_ext SHARED - src/csharp/ext/grpc_csharp_ext.c -) - -if(WIN32 AND MSVC) - set_target_properties(grpc_csharp_ext PROPERTIES COMPILE_PDB_NAME "grpc_csharp_ext" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/grpc_csharp_ext.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(grpc_csharp_ext - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_csharp_ext - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - - -if (gRPC_INSTALL) - install(TARGETS grpc_csharp_ext EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_library(ares - third_party/cares/cares/ares__close_sockets.c - third_party/cares/cares/ares__get_hostent.c - third_party/cares/cares/ares__read_line.c - third_party/cares/cares/ares__timeval.c - third_party/cares/cares/ares_cancel.c - third_party/cares/cares/ares_create_query.c - third_party/cares/cares/ares_data.c - third_party/cares/cares/ares_destroy.c - third_party/cares/cares/ares_expand_name.c - third_party/cares/cares/ares_expand_string.c - third_party/cares/cares/ares_fds.c - third_party/cares/cares/ares_free_hostent.c - third_party/cares/cares/ares_free_string.c - third_party/cares/cares/ares_getenv.c - third_party/cares/cares/ares_gethostbyaddr.c - third_party/cares/cares/ares_gethostbyname.c - third_party/cares/cares/ares_getnameinfo.c - third_party/cares/cares/ares_getopt.c - third_party/cares/cares/ares_getsock.c - third_party/cares/cares/ares_init.c - third_party/cares/cares/ares_library_init.c - third_party/cares/cares/ares_llist.c - third_party/cares/cares/ares_mkquery.c - third_party/cares/cares/ares_nowarn.c - third_party/cares/cares/ares_options.c - third_party/cares/cares/ares_parse_a_reply.c - third_party/cares/cares/ares_parse_aaaa_reply.c - third_party/cares/cares/ares_parse_mx_reply.c - third_party/cares/cares/ares_parse_naptr_reply.c - third_party/cares/cares/ares_parse_ns_reply.c - third_party/cares/cares/ares_parse_ptr_reply.c - third_party/cares/cares/ares_parse_soa_reply.c - third_party/cares/cares/ares_parse_srv_reply.c - third_party/cares/cares/ares_parse_txt_reply.c - third_party/cares/cares/ares_platform.c - third_party/cares/cares/ares_process.c - third_party/cares/cares/ares_query.c - third_party/cares/cares/ares_search.c - third_party/cares/cares/ares_send.c - third_party/cares/cares/ares_strcasecmp.c - third_party/cares/cares/ares_strdup.c - third_party/cares/cares/ares_strerror.c - third_party/cares/cares/ares_timeout.c - third_party/cares/cares/ares_version.c - third_party/cares/cares/ares_writev.c - third_party/cares/cares/bitncmp.c - third_party/cares/cares/inet_net_pton.c - third_party/cares/cares/inet_ntop.c - third_party/cares/cares/windows_port.c -) - -if(WIN32 AND MSVC) - set_target_properties(ares PROPERTIES COMPILE_PDB_NAME "ares" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ares.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(ares - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ares - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(bad_client_test - test/core/bad_client/bad_client.c -) - -if(WIN32 AND MSVC) - set_target_properties(bad_client_test PROPERTIES COMPILE_PDB_NAME "bad_client_test" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/bad_client_test.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(bad_client_test - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_client_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(bad_ssl_test_server - test/core/bad_ssl/server_common.c -) - -if(WIN32 AND MSVC) - set_target_properties(bad_ssl_test_server PROPERTIES COMPILE_PDB_NAME "bad_ssl_test_server" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/bad_ssl_test_server.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(bad_ssl_test_server - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_test_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(end2end_tests - test/core/end2end/end2end_tests.c - test/core/end2end/end2end_test_utils.c - test/core/end2end/tests/authority_not_supported.c - test/core/end2end/tests/bad_hostname.c - test/core/end2end/tests/bad_ping.c - test/core/end2end/tests/binary_metadata.c - test/core/end2end/tests/call_creds.c - test/core/end2end/tests/cancel_after_accept.c - test/core/end2end/tests/cancel_after_client_done.c - test/core/end2end/tests/cancel_after_invoke.c - test/core/end2end/tests/cancel_after_round_trip.c - test/core/end2end/tests/cancel_before_invoke.c - test/core/end2end/tests/cancel_in_a_vacuum.c - test/core/end2end/tests/cancel_with_status.c - test/core/end2end/tests/compressed_payload.c - test/core/end2end/tests/connectivity.c - test/core/end2end/tests/default_host.c - test/core/end2end/tests/disappearing_server.c - test/core/end2end/tests/empty_batch.c - test/core/end2end/tests/filter_call_init_fails.c - test/core/end2end/tests/filter_causes_close.c - test/core/end2end/tests/filter_latency.c - test/core/end2end/tests/graceful_server_shutdown.c - test/core/end2end/tests/high_initial_seqno.c - test/core/end2end/tests/hpack_size.c - test/core/end2end/tests/idempotent_request.c - test/core/end2end/tests/invoke_large_request.c - test/core/end2end/tests/keepalive_timeout.c - test/core/end2end/tests/large_metadata.c - test/core/end2end/tests/load_reporting_hook.c - test/core/end2end/tests/max_concurrent_streams.c - test/core/end2end/tests/max_connection_age.c - test/core/end2end/tests/max_connection_idle.c - test/core/end2end/tests/max_message_length.c - test/core/end2end/tests/negative_deadline.c - test/core/end2end/tests/network_status_change.c - test/core/end2end/tests/no_logging.c - test/core/end2end/tests/no_op.c - test/core/end2end/tests/payload.c - test/core/end2end/tests/ping.c - test/core/end2end/tests/ping_pong_streaming.c - test/core/end2end/tests/proxy_auth.c - test/core/end2end/tests/registered_call.c - test/core/end2end/tests/request_with_flags.c - test/core/end2end/tests/request_with_payload.c - test/core/end2end/tests/resource_quota_server.c - test/core/end2end/tests/server_finishes_request.c - test/core/end2end/tests/shutdown_finishes_calls.c - test/core/end2end/tests/shutdown_finishes_tags.c - test/core/end2end/tests/simple_cacheable_request.c - test/core/end2end/tests/simple_delayed_request.c - test/core/end2end/tests/simple_metadata.c - test/core/end2end/tests/simple_request.c - test/core/end2end/tests/streaming_error_response.c - test/core/end2end/tests/trailing_metadata.c - test/core/end2end/tests/workaround_cronet_compression.c - test/core/end2end/tests/write_buffering.c - test/core/end2end/tests/write_buffering_at_end.c -) - -if(WIN32 AND MSVC) - set_target_properties(end2end_tests PROPERTIES COMPILE_PDB_NAME "end2end_tests" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/end2end_tests.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(end2end_tests - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(end2end_tests - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_library(end2end_nosec_tests - test/core/end2end/end2end_nosec_tests.c - test/core/end2end/end2end_test_utils.c - test/core/end2end/tests/authority_not_supported.c - test/core/end2end/tests/bad_hostname.c - test/core/end2end/tests/bad_ping.c - test/core/end2end/tests/binary_metadata.c - test/core/end2end/tests/cancel_after_accept.c - test/core/end2end/tests/cancel_after_client_done.c - test/core/end2end/tests/cancel_after_invoke.c - test/core/end2end/tests/cancel_after_round_trip.c - test/core/end2end/tests/cancel_before_invoke.c - test/core/end2end/tests/cancel_in_a_vacuum.c - test/core/end2end/tests/cancel_with_status.c - test/core/end2end/tests/compressed_payload.c - test/core/end2end/tests/connectivity.c - test/core/end2end/tests/default_host.c - test/core/end2end/tests/disappearing_server.c - test/core/end2end/tests/empty_batch.c - test/core/end2end/tests/filter_call_init_fails.c - test/core/end2end/tests/filter_causes_close.c - test/core/end2end/tests/filter_latency.c - test/core/end2end/tests/graceful_server_shutdown.c - test/core/end2end/tests/high_initial_seqno.c - test/core/end2end/tests/hpack_size.c - test/core/end2end/tests/idempotent_request.c - test/core/end2end/tests/invoke_large_request.c - test/core/end2end/tests/keepalive_timeout.c - test/core/end2end/tests/large_metadata.c - test/core/end2end/tests/load_reporting_hook.c - test/core/end2end/tests/max_concurrent_streams.c - test/core/end2end/tests/max_connection_age.c - test/core/end2end/tests/max_connection_idle.c - test/core/end2end/tests/max_message_length.c - test/core/end2end/tests/negative_deadline.c - test/core/end2end/tests/network_status_change.c - test/core/end2end/tests/no_logging.c - test/core/end2end/tests/no_op.c - test/core/end2end/tests/payload.c - test/core/end2end/tests/ping.c - test/core/end2end/tests/ping_pong_streaming.c - test/core/end2end/tests/proxy_auth.c - test/core/end2end/tests/registered_call.c - test/core/end2end/tests/request_with_flags.c - test/core/end2end/tests/request_with_payload.c - test/core/end2end/tests/resource_quota_server.c - test/core/end2end/tests/server_finishes_request.c - test/core/end2end/tests/shutdown_finishes_calls.c - test/core/end2end/tests/shutdown_finishes_tags.c - test/core/end2end/tests/simple_cacheable_request.c - test/core/end2end/tests/simple_delayed_request.c - test/core/end2end/tests/simple_metadata.c - test/core/end2end/tests/simple_request.c - test/core/end2end/tests/streaming_error_response.c - test/core/end2end/tests/trailing_metadata.c - test/core/end2end/tests/workaround_cronet_compression.c - test/core/end2end/tests/write_buffering.c - test/core/end2end/tests/write_buffering_at_end.c -) - -if(WIN32 AND MSVC) - set_target_properties(end2end_nosec_tests PROPERTIES COMPILE_PDB_NAME "end2end_nosec_tests" - COMPILE_PDB_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}" - ) - if (gRPC_INSTALL) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/end2end_nosec_tests.pdb - DESTINATION ${gRPC_INSTALL_LIBDIR} OPTIONAL - ) - endif() -endif() - - -target_include_directories(end2end_nosec_tests - PUBLIC $ $ - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${ZLIB_INCLUDE_DIR} - PRIVATE ${BENCHMARK}/include - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(end2end_nosec_tests - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - - -endif (gRPC_BUILD_TESTS) - -if (gRPC_BUILD_TESTS) - -add_executable(alarm_test - test/core/surface/alarm_test.c -) - - -target_include_directories(alarm_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alarm_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(algorithm_test - test/core/compression/algorithm_test.c -) - - -target_include_directories(algorithm_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(algorithm_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alloc_test - test/core/support/alloc_test.c -) - - -target_include_directories(alloc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alloc_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alpn_test - test/core/transport/chttp2/alpn_test.c -) - - -target_include_directories(alpn_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(alpn_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(arena_test - test/core/support/arena_test.c -) - - -target_include_directories(arena_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(arena_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bad_server_response_test - test/core/end2end/bad_server_response_test.c -) - - -target_include_directories(bad_server_response_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_server_response_test - ${_gRPC_ALLTARGETS_LIBRARIES} - test_tcp_server - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bdp_estimator_test - test/core/transport/bdp_estimator_test.c -) - - -target_include_directories(bdp_estimator_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bdp_estimator_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bin_decoder_test - test/core/transport/chttp2/bin_decoder_test.c -) - - -target_include_directories(bin_decoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bin_decoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(bin_encoder_test - test/core/transport/chttp2/bin_encoder_test.c -) - - -target_include_directories(bin_encoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bin_encoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_context_test - test/core/census/context_test.c -) - - -target_include_directories(census_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_intrusive_hash_map_test - test/core/census/intrusive_hash_map_test.c -) - - -target_include_directories(census_intrusive_hash_map_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_intrusive_hash_map_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_resource_test - test/core/census/resource_test.c -) - - -target_include_directories(census_resource_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_resource_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(census_trace_context_test - test/core/census/trace_context_test.c -) - - -target_include_directories(census_trace_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(census_trace_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_create_test - test/core/surface/channel_create_test.c -) - - -target_include_directories(channel_create_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(channel_create_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(check_epollexclusive - test/build/check_epollexclusive.c -) - - -target_include_directories(check_epollexclusive - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(check_epollexclusive - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS check_epollexclusive EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(chttp2_hpack_encoder_test - test/core/transport/chttp2/hpack_encoder_test.c -) - - -target_include_directories(chttp2_hpack_encoder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_hpack_encoder_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(chttp2_stream_map_test - test/core/transport/chttp2/stream_map_test.c -) - - -target_include_directories(chttp2_stream_map_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_stream_map_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(chttp2_varint_test - test/core/transport/chttp2/varint_test.c -) - - -target_include_directories(chttp2_varint_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(chttp2_varint_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(combiner_test - test/core/iomgr/combiner_test.c -) - - -target_include_directories(combiner_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(combiner_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(compression_test - test/core/compression/compression_test.c -) - - -target_include_directories(compression_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(compression_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(concurrent_connectivity_test - test/core/surface/concurrent_connectivity_test.c -) - - -target_include_directories(concurrent_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(concurrent_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(connection_refused_test - test/core/end2end/connection_refused_test.c -) - - -target_include_directories(connection_refused_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(connection_refused_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(dns_resolver_connectivity_test - test/core/client_channel/resolvers/dns_resolver_connectivity_test.c -) - - -target_include_directories(dns_resolver_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dns_resolver_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(dns_resolver_test - test/core/client_channel/resolvers/dns_resolver_test.c -) - - -target_include_directories(dns_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dns_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(dualstack_socket_test - test/core/end2end/dualstack_socket_test.c -) - - -target_include_directories(dualstack_socket_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(dualstack_socket_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(endpoint_pair_test - test/core/iomgr/endpoint_pair_test.c -) - - -target_include_directories(endpoint_pair_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(endpoint_pair_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(error_test - test/core/iomgr/error_test.c -) - - -target_include_directories(error_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(error_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(ev_epollsig_linux_test - test/core/iomgr/ev_epollsig_linux_test.c -) - - -target_include_directories(ev_epollsig_linux_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ev_epollsig_linux_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fake_resolver_test - test/core/client_channel/resolvers/fake_resolver_test.c -) - - -target_include_directories(fake_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fake_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fd_conservation_posix_test - test/core/iomgr/fd_conservation_posix_test.c -) - - -target_include_directories(fd_conservation_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fd_conservation_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fd_posix_test - test/core/iomgr/fd_posix_test.c -) - - -target_include_directories(fd_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fd_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fling_client - test/core/fling/client.c -) - - -target_include_directories(fling_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_client - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(fling_server - test/core/fling/server.c -) - - -target_include_directories(fling_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fling_stream_test - test/core/fling/fling_stream_test.c -) - - -target_include_directories(fling_stream_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_stream_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(fling_test - test/core/fling/fling_test.c -) - - -target_include_directories(fling_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(fling_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) - -add_executable(gen_hpack_tables - tools/codegen/core/gen_hpack_tables.c -) - - -target_include_directories(gen_hpack_tables - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_hpack_tables - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr - grpc -) - - -if (gRPC_INSTALL) - install(TARGETS gen_hpack_tables EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(gen_legal_metadata_characters - tools/codegen/core/gen_legal_metadata_characters.c -) - - -target_include_directories(gen_legal_metadata_characters - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_legal_metadata_characters - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -if (gRPC_INSTALL) - install(TARGETS gen_legal_metadata_characters EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(gen_percent_encoding_tables - tools/codegen/core/gen_percent_encoding_tables.c -) - - -target_include_directories(gen_percent_encoding_tables - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gen_percent_encoding_tables - ${_gRPC_ALLTARGETS_LIBRARIES} -) - - -if (gRPC_INSTALL) - install(TARGETS gen_percent_encoding_tables EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(goaway_server_test - test/core/end2end/goaway_server_test.c -) - - -target_include_directories(goaway_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(goaway_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_avl_test - test/core/support/avl_test.c -) - - -target_include_directories(gpr_avl_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_avl_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_backoff_test - test/core/support/backoff_test.c -) - - -target_include_directories(gpr_backoff_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_backoff_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_cmdline_test - test/core/support/cmdline_test.c -) - - -target_include_directories(gpr_cmdline_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_cmdline_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_cpu_test - test/core/support/cpu_test.c -) - - -target_include_directories(gpr_cpu_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_cpu_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_env_test - test/core/support/env_test.c -) - - -target_include_directories(gpr_env_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_env_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_histogram_test - test/core/support/histogram_test.c -) - - -target_include_directories(gpr_histogram_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_histogram_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_host_port_test - test/core/support/host_port_test.c -) - - -target_include_directories(gpr_host_port_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_host_port_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_log_test - test/core/support/log_test.c -) - - -target_include_directories(gpr_log_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_log_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_mpscq_test - test/core/support/mpscq_test.c -) - - -target_include_directories(gpr_mpscq_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_mpscq_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_spinlock_test - test/core/support/spinlock_test.c -) - - -target_include_directories(gpr_spinlock_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_spinlock_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_stack_lockfree_test - test/core/support/stack_lockfree_test.c -) - - -target_include_directories(gpr_stack_lockfree_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_stack_lockfree_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_string_test - test/core/support/string_test.c -) - - -target_include_directories(gpr_string_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_string_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_sync_test - test/core/support/sync_test.c -) - - -target_include_directories(gpr_sync_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_sync_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_thd_test - test/core/support/thd_test.c -) - - -target_include_directories(gpr_thd_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_thd_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_time_test - test/core/support/time_test.c -) - - -target_include_directories(gpr_time_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_time_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_tls_test - test/core/support/tls_test.c -) - - -target_include_directories(gpr_tls_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_tls_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(gpr_useful_test - test/core/support/useful_test.c -) - - -target_include_directories(gpr_useful_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(gpr_useful_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_auth_context_test - test/core/security/auth_context_test.c -) - - -target_include_directories(grpc_auth_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_auth_context_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_b64_test - test/core/slice/b64_test.c -) - - -target_include_directories(grpc_b64_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_b64_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_byte_buffer_reader_test - test/core/surface/byte_buffer_reader_test.c -) - - -target_include_directories(grpc_byte_buffer_reader_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_byte_buffer_reader_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_channel_args_test - test/core/channel/channel_args_test.c -) - - -target_include_directories(grpc_channel_args_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_channel_args_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_channel_stack_test - test/core/channel/channel_stack_test.c -) - - -target_include_directories(grpc_channel_stack_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_channel_stack_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_completion_queue_test - test/core/surface/completion_queue_test.c -) - - -target_include_directories(grpc_completion_queue_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_completion_queue_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_completion_queue_threading_test - test/core/surface/completion_queue_threading_test.c -) - - -target_include_directories(grpc_completion_queue_threading_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_completion_queue_threading_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_create_jwt - test/core/security/create_jwt.c -) - - -target_include_directories(grpc_create_jwt - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_create_jwt - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_create_jwt EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_credentials_test - test/core/security/credentials_test.c -) - - -target_include_directories(grpc_credentials_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_credentials_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_fetch_oauth2 - test/core/security/fetch_oauth2.c -) - - -target_include_directories(grpc_fetch_oauth2 - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_fetch_oauth2 - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_invalid_channel_args_test - test/core/surface/invalid_channel_args_test.c -) - - -target_include_directories(grpc_invalid_channel_args_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_invalid_channel_args_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(grpc_json_token_test - test/core/security/json_token_test.c -) - - -target_include_directories(grpc_json_token_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_json_token_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_jwt_verifier_test - test/core/security/jwt_verifier_test.c -) - - -target_include_directories(grpc_jwt_verifier_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_jwt_verifier_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_print_google_default_creds_token - test/core/security/print_google_default_creds_token.c -) - - -target_include_directories(grpc_print_google_default_creds_token - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_print_google_default_creds_token - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_print_google_default_creds_token EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_security_connector_test - test/core/security/security_connector_test.c -) - - -target_include_directories(grpc_security_connector_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_security_connector_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_verify_jwt - test/core/security/verify_jwt.c -) - - -target_include_directories(grpc_verify_jwt - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(grpc_verify_jwt - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_verify_jwt EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(handshake_client - test/core/handshake/client_ssl.c -) - - -target_include_directories(handshake_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(handshake_client - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(handshake_server - test/core/handshake/server_ssl.c -) - - -target_include_directories(handshake_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(handshake_server - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_parser_test - test/core/transport/chttp2/hpack_parser_test.c -) - - -target_include_directories(hpack_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_table_test - test/core/transport/chttp2/hpack_table_test.c -) - - -target_include_directories(hpack_table_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_table_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_parser_test - test/core/http/parser_test.c -) - - -target_include_directories(http_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(http_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(httpcli_format_request_test - test/core/http/format_request_test.c -) - - -target_include_directories(httpcli_format_request_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpcli_format_request_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(httpcli_test - test/core/http/httpcli_test.c -) - - -target_include_directories(httpcli_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpcli_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(httpscli_test - test/core/http/httpscli_test.c -) - - -target_include_directories(httpscli_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(httpscli_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(init_test - test/core/surface/init_test.c -) - - -target_include_directories(init_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(init_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(invalid_call_argument_test - test/core/end2end/invalid_call_argument_test.c -) - - -target_include_directories(invalid_call_argument_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(invalid_call_argument_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_rewrite - test/core/json/json_rewrite.c -) - - -target_include_directories(json_rewrite - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_rewrite - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_rewrite_test - test/core/json/json_rewrite_test.c -) - - -target_include_directories(json_rewrite_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_rewrite_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_stream_error_test - test/core/json/json_stream_error_test.c -) - - -target_include_directories(json_stream_error_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_stream_error_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_test - test/core/json/json_test.c -) - - -target_include_directories(json_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(lame_client_test - test/core/surface/lame_client_test.c -) - - -target_include_directories(lame_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(lame_client_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(lb_policies_test - test/core/client_channel/lb_policies_test.c -) - - -target_include_directories(lb_policies_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(lb_policies_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(load_file_test - test/core/iomgr/load_file_test.c -) - - -target_include_directories(load_file_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(load_file_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(memory_profile_client - test/core/memory_usage/client.c -) - - -target_include_directories(memory_profile_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_client - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(memory_profile_server - test/core/memory_usage/server.c -) - - -target_include_directories(memory_profile_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_server - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(memory_profile_test - test/core/memory_usage/memory_usage_test.c -) - - -target_include_directories(memory_profile_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(memory_profile_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(message_compress_test - test/core/compression/message_compress_test.c -) - - -target_include_directories(message_compress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(message_compress_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(minimal_stack_is_minimal_test - test/core/channel/minimal_stack_is_minimal_test.c -) - - -target_include_directories(minimal_stack_is_minimal_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(minimal_stack_is_minimal_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(mlog_test - test/core/census/mlog_test.c -) - - -target_include_directories(mlog_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(mlog_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(multiple_server_queues_test - test/core/end2end/multiple_server_queues_test.c -) - - -target_include_directories(multiple_server_queues_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(multiple_server_queues_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(murmur_hash_test - test/core/support/murmur_hash_test.c -) - - -target_include_directories(murmur_hash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(murmur_hash_test - ${_gRPC_ALLTARGETS_LIBRARIES} - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(no_server_test - test/core/end2end/no_server_test.c -) - - -target_include_directories(no_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(no_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(num_external_connectivity_watchers_test - test/core/surface/num_external_connectivity_watchers_test.c -) - - -target_include_directories(num_external_connectivity_watchers_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(num_external_connectivity_watchers_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(parse_address_test - test/core/client_channel/parse_address_test.c -) - - -target_include_directories(parse_address_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(parse_address_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_encoding_test - test/core/slice/percent_encoding_test.c -) - - -target_include_directories(percent_encoding_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_encoding_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(pollset_set_test - test/core/iomgr/pollset_set_test.c -) - - -target_include_directories(pollset_set_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(pollset_set_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(resolve_address_posix_test - test/core/iomgr/resolve_address_posix_test.c -) - - -target_include_directories(resolve_address_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resolve_address_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(resolve_address_test - test/core/iomgr/resolve_address_test.c -) - - -target_include_directories(resolve_address_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resolve_address_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(resource_quota_test - test/core/iomgr/resource_quota_test.c -) - - -target_include_directories(resource_quota_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(resource_quota_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_channel_create_test - test/core/surface/secure_channel_create_test.c -) - - -target_include_directories(secure_channel_create_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(secure_channel_create_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_endpoint_test - test/core/security/secure_endpoint_test.c -) - - -target_include_directories(secure_endpoint_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(secure_endpoint_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sequential_connectivity_test - test/core/surface/sequential_connectivity_test.c -) - - -target_include_directories(sequential_connectivity_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sequential_connectivity_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_chttp2_test - test/core/surface/server_chttp2_test.c -) - - -target_include_directories(server_chttp2_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_chttp2_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_test - test/core/surface/server_test.c -) - - -target_include_directories(server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_buffer_test - test/core/slice/slice_buffer_test.c -) - - -target_include_directories(slice_buffer_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_buffer_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_hash_table_test - test/core/slice/slice_hash_table_test.c -) - - -target_include_directories(slice_hash_table_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_hash_table_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_string_helpers_test - test/core/slice/slice_string_helpers_test.c -) - - -target_include_directories(slice_string_helpers_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_string_helpers_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(slice_test - test/core/slice/slice_test.c -) - - -target_include_directories(slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(slice_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sockaddr_resolver_test - test/core/client_channel/resolvers/sockaddr_resolver_test.c -) - - -target_include_directories(sockaddr_resolver_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sockaddr_resolver_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(sockaddr_utils_test - test/core/iomgr/sockaddr_utils_test.c -) - - -target_include_directories(sockaddr_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(sockaddr_utils_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(socket_utils_test - test/core/iomgr/socket_utils_test.c -) - - -target_include_directories(socket_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(socket_utils_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(status_conversion_test - test/core/transport/status_conversion_test.c -) - - -target_include_directories(status_conversion_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(status_conversion_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stream_compression_test - test/core/compression/stream_compression_test.c -) - - -target_include_directories(stream_compression_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(stream_compression_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stream_owned_slice_test - test/core/transport/stream_owned_slice_test.c -) - - -target_include_directories(stream_owned_slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(stream_owned_slice_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_client_posix_test - test/core/iomgr/tcp_client_posix_test.c -) - - -target_include_directories(tcp_client_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_client_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(tcp_client_uv_test - test/core/iomgr/tcp_client_uv_test.c -) - - -target_include_directories(tcp_client_uv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_client_uv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_posix_test - test/core/iomgr/tcp_posix_test.c -) - - -target_include_directories(tcp_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(tcp_server_posix_test - test/core/iomgr/tcp_server_posix_test.c -) - - -target_include_directories(tcp_server_posix_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_server_posix_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(tcp_server_uv_test - test/core/iomgr/tcp_server_uv_test.c -) - - -target_include_directories(tcp_server_uv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(tcp_server_uv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(time_averaged_stats_test - test/core/iomgr/time_averaged_stats_test.c -) - - -target_include_directories(time_averaged_stats_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(time_averaged_stats_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timeout_encoding_test - test/core/transport/timeout_encoding_test.c -) - - -target_include_directories(timeout_encoding_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timeout_encoding_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timer_heap_test - test/core/iomgr/timer_heap_test.c -) - - -target_include_directories(timer_heap_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timer_heap_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(timer_list_test - test/core/iomgr/timer_list_test.c -) - - -target_include_directories(timer_list_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(timer_list_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_connectivity_state_test - test/core/transport/connectivity_state_test.c -) - - -target_include_directories(transport_connectivity_state_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_connectivity_state_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_metadata_test - test/core/transport/metadata_test.c -) - - -target_include_directories(transport_metadata_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_metadata_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(transport_pid_controller_test - test/core/transport/pid_controller_test.c -) - - -target_include_directories(transport_pid_controller_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_pid_controller_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(transport_security_test - test/core/tsi/transport_security_test.c -) - - -target_include_directories(transport_security_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(transport_security_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(udp_server_test - test/core/iomgr/udp_server_test.c -) - - -target_include_directories(udp_server_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(udp_server_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(uri_parser_test - test/core/client_channel/uri_parser_test.c -) - - -target_include_directories(uri_parser_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(uri_parser_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(wakeup_fd_cv_test - test/core/iomgr/wakeup_fd_cv_test.c -) - - -target_include_directories(wakeup_fd_cv_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(wakeup_fd_cv_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(alarm_cpp_test - test/cpp/common/alarm_cpp_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(alarm_cpp_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(alarm_cpp_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(async_end2end_test - test/cpp/end2end/async_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(async_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(async_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(auth_property_iterator_test - test/cpp/common/auth_property_iterator_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(auth_property_iterator_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(auth_property_iterator_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_arena - test/cpp/microbenchmarks/bm_arena.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_arena - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_arena - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_call_create - test/cpp/microbenchmarks/bm_call_create.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_call_create - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_call_create - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_chttp2_hpack - test/cpp/microbenchmarks/bm_chttp2_hpack.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_chttp2_hpack - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_chttp2_hpack - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_chttp2_transport - test/cpp/microbenchmarks/bm_chttp2_transport.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_chttp2_transport - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_chttp2_transport - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_closure - test/cpp/microbenchmarks/bm_closure.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_closure - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_closure - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_cq - test/cpp/microbenchmarks/bm_cq.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_cq - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_cq - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_cq_multiple_threads - test/cpp/microbenchmarks/bm_cq_multiple_threads.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_cq_multiple_threads - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_cq_multiple_threads - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_error - test/cpp/microbenchmarks/bm_error.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_error - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_error - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_streaming_ping_pong - test/cpp/microbenchmarks/bm_fullstack_streaming_ping_pong.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_streaming_ping_pong - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_streaming_ping_pong - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_streaming_pump - test/cpp/microbenchmarks/bm_fullstack_streaming_pump.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_streaming_pump - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_streaming_pump - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_trickle - test/cpp/microbenchmarks/bm_fullstack_trickle.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_trickle - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_trickle - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_fullstack_unary_ping_pong - test/cpp/microbenchmarks/bm_fullstack_unary_ping_pong.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_fullstack_unary_ping_pong - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_fullstack_unary_ping_pong - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_metadata - test/cpp/microbenchmarks/bm_metadata.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_metadata - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_metadata - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bm_pollset - test/cpp/microbenchmarks/bm_pollset.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(bm_pollset - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(bm_pollset - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_benchmark - benchmark - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_arguments_test - test/cpp/common/channel_arguments_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(channel_arguments_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(channel_arguments_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(channel_filter_test - test/cpp/common/channel_filter_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(channel_filter_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(channel_filter_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cli_call_test - test/cpp/util/cli_call_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cli_call_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cli_call_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(client_crash_test - test/cpp/end2end/client_crash_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_crash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_crash_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_crash_test_server - test/cpp/end2end/client_crash_test_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_crash_test_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_crash_test_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_lb_end2end_test - test/cpp/end2end/client_lb_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(client_lb_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(client_lb_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(codegen_test_full - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - test/cpp/codegen/codegen_test_full.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) - -target_include_directories(codegen_test_full - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(codegen_test_full - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(codegen_test_minimal - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/control.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/payloads.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/services.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/stats.grpc.pb.h - test/cpp/codegen/codegen_test_minimal.cc - src/cpp/codegen/codegen_init.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/control.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/payloads.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/services.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/stats.proto -) - -target_include_directories(codegen_test_minimal - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(codegen_test_minimal - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(credentials_test - test/cpp/client/credentials_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(credentials_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(credentials_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_byte_buffer_test - test/cpp/util/byte_buffer_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_byte_buffer_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_byte_buffer_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_slice_test - test/cpp/util/slice_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_slice_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_slice_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_string_ref_test - test/cpp/util/string_ref_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_string_ref_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_string_ref_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(cxx_time_test - test/cpp/util/time_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(cxx_time_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(cxx_time_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(end2end_test - test/cpp/end2end/end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(error_details_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - test/cpp/util/error_details_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) - -target_include_directories(error_details_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(error_details_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_error_details - grpc++ - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(filter_end2end_test - test/cpp/end2end/filter_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(filter_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(filter_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(generic_end2end_test - test/cpp/end2end/generic_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(generic_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(generic_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(golden_file_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/compiler_test.grpc.pb.h - test/cpp/codegen/golden_file_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/compiler_test.proto -) - -target_include_directories(golden_file_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(golden_file_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpc_cli - test/cpp/util/grpc_cli.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(grpc_cli - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cli - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_proto_reflection_desc_db - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) - -add_executable(grpc_cpp_plugin - src/compiler/cpp_plugin.cc -) - - -target_include_directories(grpc_cpp_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_cpp_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_cpp_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_csharp_plugin - src/compiler/csharp_plugin.cc -) - - -target_include_directories(grpc_csharp_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_csharp_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_csharp_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_node_plugin - src/compiler/node_plugin.cc -) - - -target_include_directories(grpc_node_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_node_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_node_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_objective_c_plugin - src/compiler/objective_c_plugin.cc -) - - -target_include_directories(grpc_objective_c_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_objective_c_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_objective_c_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_php_plugin - src/compiler/php_plugin.cc -) - - -target_include_directories(grpc_php_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_php_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_php_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_python_plugin - src/compiler/python_plugin.cc -) - - -target_include_directories(grpc_python_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_python_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_python_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - - -add_executable(grpc_ruby_plugin - src/compiler/ruby_plugin.cc -) - - -target_include_directories(grpc_ruby_plugin - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_ruby_plugin - ${_gRPC_PROTOBUF_PROTOC_LIBRARIES} - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_plugin_support -) - - -if (gRPC_INSTALL) - install(TARGETS grpc_ruby_plugin EXPORT gRPCTargets - RUNTIME DESTINATION ${gRPC_INSTALL_BINDIR} - LIBRARY DESTINATION ${gRPC_INSTALL_LIBDIR} - ARCHIVE DESTINATION ${gRPC_INSTALL_LIBDIR} - ) -endif() - -if (gRPC_BUILD_TESTS) - -add_executable(grpc_tool_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - test/cpp/util/grpc_tool_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) - -target_include_directories(grpc_tool_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpc_tool_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_cli_libs - grpc++_proto_reflection_desc_db - grpc++_reflection - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_api_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/grpclb/grpclb_api_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_api_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpclb_api_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_end2end_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/end2end/grpclb_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpclb_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(grpclb_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/lb/v1/load_balancer.grpc.pb.h - test/cpp/grpclb/grpclb_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/lb/v1/load_balancer.proto -) - -target_include_directories(grpclb_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(grpclb_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(health_service_end2end_test - test/cpp/end2end/health_service_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(health_service_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(health_service_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(http2_client - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(http2_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(http2_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - http2_client_main - grpc++_test_util - grpc_test_util - grpc++ - grpc - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hybrid_end2end_test - test/cpp/end2end/hybrid_end2end_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(hybrid_end2end_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(hybrid_end2end_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_client - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_client_main - interop_client_helper - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_server - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - interop_server_main - interop_server_helper - interop_server_lib - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(interop_test - test/cpp/interop/interop_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(interop_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(interop_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(json_run_localhost - test/cpp/qps/json_run_localhost.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(json_run_localhost - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(json_run_localhost - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(memory_test - test/core/support/memory_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(memory_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(memory_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(metrics_client - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.h - test/cpp/interop/metrics_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/metrics.proto -) - -target_include_directories(metrics_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(metrics_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(mock_test - test/cpp/end2end/mock_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(mock_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(mock_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(noop-benchmark - test/cpp/microbenchmarks/noop-benchmark.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(noop-benchmark - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(noop-benchmark - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - benchmark - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(proto_server_reflection_test - test/cpp/end2end/proto_server_reflection_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(proto_server_reflection_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(proto_server_reflection_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_proto_reflection_desc_db - grpc++_reflection - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(proto_utils_test - test/cpp/codegen/proto_utils_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(proto_utils_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(proto_utils_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(qps_interarrival_test - test/cpp/qps/qps_interarrival_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_interarrival_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_interarrival_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(qps_json_driver - test/cpp/qps/qps_json_driver.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_json_driver - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_json_driver - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(qps_openloop_test - test/cpp/qps/qps_openloop_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_openloop_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_openloop_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(qps_worker - test/cpp/qps/worker.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(qps_worker - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(qps_worker - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(reconnect_interop_client - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/reconnect_interop_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(reconnect_interop_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(reconnect_interop_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(reconnect_interop_server - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/reconnect_interop_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(reconnect_interop_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(reconnect_interop_server - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - reconnect_server - test_tcp_server - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(secure_auth_context_test - test/cpp/common/secure_auth_context_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(secure_auth_context_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(secure_auth_context_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(secure_sync_unary_ping_pong_test - test/cpp/qps/secure_sync_unary_ping_pong_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(secure_sync_unary_ping_pong_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(secure_sync_unary_ping_pong_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - qps - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_builder_plugin_test - test/cpp/end2end/server_builder_plugin_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_builder_plugin_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_builder_plugin_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_builder_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - test/cpp/server/server_builder_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) - -target_include_directories(server_builder_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_builder_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - gpr_test_util - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_context_test_spouse_test - test/cpp/test/server_context_test_spouse_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_context_test_spouse_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_context_test_spouse_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(server_crash_test - test/cpp/end2end/server_crash_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_crash_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_crash_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_crash_test_client - test/cpp/end2end/server_crash_test_client.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(server_crash_test_client - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_crash_test_client - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_request_call_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo_messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/echo.grpc.pb.h - test/cpp/server/server_request_call_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo_messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/echo.proto -) - -target_include_directories(server_request_call_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(server_request_call_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - gpr_test_util - grpc++ - grpc - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(shutdown_test - test/cpp/end2end/shutdown_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(shutdown_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(shutdown_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(status_test - test/cpp/util/status_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(status_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(status_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(streaming_throughput_test - test/cpp/end2end/streaming_throughput_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(streaming_throughput_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(streaming_throughput_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(stress_test - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/empty.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/messages.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/metrics.grpc.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.cc - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.pb.h - ${_gRPC_PROTO_GENS_DIR}/src/proto/grpc/testing/test.grpc.pb.h - test/cpp/interop/interop_client.cc - test/cpp/interop/stress_interop_client.cc - test/cpp/interop/stress_test.cc - test/cpp/util/metrics_server.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/empty.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/messages.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/metrics.proto -) -protobuf_generate_grpc_cpp( - src/proto/grpc/testing/test.proto -) - -target_include_directories(stress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(stress_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(thread_manager_test - test/cpp/thread_manager/thread_manager_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(thread_manager_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(thread_manager_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++ - grpc - gpr - grpc++_test_config - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(thread_stress_test - test/cpp/end2end/thread_stress_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(thread_stress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(thread_stress_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(writes_per_rpc_test - test/cpp/performance/writes_per_rpc_test.cc - third_party/googletest/googletest/src/gtest-all.cc - third_party/googletest/googlemock/src/gmock-all.cc -) - - -target_include_directories(writes_per_rpc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include - PRIVATE third_party/googletest/googletest/include - PRIVATE third_party/googletest/googletest - PRIVATE third_party/googletest/googlemock/include - PRIVATE third_party/googletest/googlemock - PRIVATE ${_gRPC_PROTO_GENS_DIR} -) - -target_link_libraries(writes_per_rpc_test - ${_gRPC_PROTOBUF_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc++_test_util - grpc_test_util - grpc++ - grpc - gpr_test_util - gpr - ${_gRPC_GFLAGS_LIBRARIES} -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(public_headers_must_be_c89 - test/core/surface/public_headers_must_be_c89.c -) - - -target_include_directories(public_headers_must_be_c89 - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(public_headers_must_be_c89 - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(badreq_bad_client_test - test/core/bad_client/tests/badreq.c -) - - -target_include_directories(badreq_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(badreq_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(connection_prefix_bad_client_test - test/core/bad_client/tests/connection_prefix.c -) - - -target_include_directories(connection_prefix_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(connection_prefix_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(head_of_line_blocking_bad_client_test - test/core/bad_client/tests/head_of_line_blocking.c -) - - -target_include_directories(head_of_line_blocking_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(head_of_line_blocking_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(headers_bad_client_test - test/core/bad_client/tests/headers.c -) - - -target_include_directories(headers_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(headers_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(initial_settings_frame_bad_client_test - test/core/bad_client/tests/initial_settings_frame.c -) - - -target_include_directories(initial_settings_frame_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(initial_settings_frame_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(large_metadata_bad_client_test - test/core/bad_client/tests/large_metadata.c -) - - -target_include_directories(large_metadata_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(large_metadata_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_registered_method_bad_client_test - test/core/bad_client/tests/server_registered_method.c -) - - -target_include_directories(server_registered_method_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_registered_method_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(simple_request_bad_client_test - test/core/bad_client/tests/simple_request.c -) - - -target_include_directories(simple_request_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(simple_request_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(unknown_frame_bad_client_test - test/core/bad_client/tests/unknown_frame.c -) - - -target_include_directories(unknown_frame_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(unknown_frame_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(window_overflow_bad_client_test - test/core/bad_client/tests/window_overflow.c -) - - -target_include_directories(window_overflow_bad_client_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(window_overflow_bad_client_test - ${_gRPC_SSL_LIBRARIES} - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_client_test - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bad_ssl_cert_server - test/core/bad_ssl/servers/cert.c -) - - -target_include_directories(bad_ssl_cert_server - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_cert_server - ${_gRPC_ALLTARGETS_LIBRARIES} - bad_ssl_test_server - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(bad_ssl_cert_test - test/core/bad_ssl/bad_ssl_test.c -) - - -target_include_directories(bad_ssl_cert_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(bad_ssl_cert_test - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_census_test - test/core/end2end/fixtures/h2_census.c -) - - -target_include_directories(h2_census_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_census_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_compress_test - test/core/end2end/fixtures/h2_compress.c -) - - -target_include_directories(h2_compress_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_compress_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_fakesec_test - test/core/end2end/fixtures/h2_fakesec.c -) - - -target_include_directories(h2_fakesec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fakesec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_fd_test - test/core/end2end/fixtures/h2_fd.c -) - - -target_include_directories(h2_fd_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fd_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full_test - test/core/end2end/fixtures/h2_full.c -) - - -target_include_directories(h2_full_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(h2_full+pipe_test - test/core/end2end/fixtures/h2_full+pipe.c -) - - -target_include_directories(h2_full+pipe_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+pipe_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+trace_test - test/core/end2end/fixtures/h2_full+trace.c -) - - -target_include_directories(h2_full+trace_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+trace_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+workarounds_test - test/core/end2end/fixtures/h2_full+workarounds.c -) - - -target_include_directories(h2_full+workarounds_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+workarounds_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_http_proxy_test - test/core/end2end/fixtures/h2_http_proxy.c -) - - -target_include_directories(h2_http_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_http_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_load_reporting_test - test/core/end2end/fixtures/h2_load_reporting.c -) - - -target_include_directories(h2_load_reporting_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_load_reporting_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_oauth2_test - test/core/end2end/fixtures/h2_oauth2.c -) - - -target_include_directories(h2_oauth2_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_oauth2_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_proxy_test - test/core/end2end/fixtures/h2_proxy.c -) - - -target_include_directories(h2_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_test - test/core/end2end/fixtures/h2_sockpair.c -) - - -target_include_directories(h2_sockpair_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair+trace_test - test/core/end2end/fixtures/h2_sockpair+trace.c -) - - -target_include_directories(h2_sockpair+trace_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair+trace_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_1byte_test - test/core/end2end/fixtures/h2_sockpair_1byte.c -) - - -target_include_directories(h2_sockpair_1byte_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_1byte_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_test - test/core/end2end/fixtures/h2_ssl.c -) - - -target_include_directories(h2_ssl_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_cert_test - test/core/end2end/fixtures/h2_ssl_cert.c -) - - -target_include_directories(h2_ssl_cert_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_cert_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_ssl_proxy_test - test/core/end2end/fixtures/h2_ssl_proxy.c -) - - -target_include_directories(h2_ssl_proxy_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_ssl_proxy_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_uds_test - test/core/end2end/fixtures/h2_uds.c -) - - -target_include_directories(h2_uds_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_uds_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(inproc_test - test/core/end2end/fixtures/inproc.c -) - - -target_include_directories(inproc_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(inproc_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_tests - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_census_nosec_test - test/core/end2end/fixtures/h2_census.c -) - - -target_include_directories(h2_census_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_census_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_compress_nosec_test - test/core/end2end/fixtures/h2_compress.c -) - - -target_include_directories(h2_compress_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_compress_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_fd_nosec_test - test/core/end2end/fixtures/h2_fd.c -) - - -target_include_directories(h2_fd_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_fd_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full_nosec_test - test/core/end2end/fixtures/h2_full.c -) - - -target_include_directories(h2_full_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX) - -add_executable(h2_full+pipe_nosec_test - test/core/end2end/fixtures/h2_full+pipe.c -) - - -target_include_directories(h2_full+pipe_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+pipe_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+trace_nosec_test - test/core/end2end/fixtures/h2_full+trace.c -) - - -target_include_directories(h2_full+trace_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+trace_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_full+workarounds_nosec_test - test/core/end2end/fixtures/h2_full+workarounds.c -) - - -target_include_directories(h2_full+workarounds_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_full+workarounds_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_http_proxy_nosec_test - test/core/end2end/fixtures/h2_http_proxy.c -) - - -target_include_directories(h2_http_proxy_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_http_proxy_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_load_reporting_nosec_test - test/core/end2end/fixtures/h2_load_reporting.c -) - - -target_include_directories(h2_load_reporting_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_load_reporting_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_proxy_nosec_test - test/core/end2end/fixtures/h2_proxy.c -) - - -target_include_directories(h2_proxy_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_proxy_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_nosec_test - test/core/end2end/fixtures/h2_sockpair.c -) - - -target_include_directories(h2_sockpair_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair+trace_nosec_test - test/core/end2end/fixtures/h2_sockpair+trace.c -) - - -target_include_directories(h2_sockpair+trace_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair+trace_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(h2_sockpair_1byte_nosec_test - test/core/end2end/fixtures/h2_sockpair_1byte.c -) - - -target_include_directories(h2_sockpair_1byte_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_sockpair_1byte_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) -if(_gRPC_PLATFORM_LINUX OR _gRPC_PLATFORM_MAC OR _gRPC_PLATFORM_POSIX) - -add_executable(h2_uds_nosec_test - test/core/end2end/fixtures/h2_uds.c -) - - -target_include_directories(h2_uds_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(h2_uds_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif() -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(inproc_nosec_test - test/core/end2end/fixtures/inproc.c -) - - -target_include_directories(inproc_nosec_test - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(inproc_nosec_test - ${_gRPC_ALLTARGETS_LIBRARIES} - end2end_nosec_tests - grpc_test_util_unsecure - grpc_unsecure - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(api_fuzzer_one_entry - test/core/end2end/fuzzers/api_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(api_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(api_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(client_fuzzer_one_entry - test/core/end2end/fuzzers/client_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(client_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(client_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(hpack_parser_fuzzer_test_one_entry - test/core/transport/chttp2/hpack_parser_fuzzer_test.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(hpack_parser_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(hpack_parser_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_request_fuzzer_test_one_entry - test/core/http/request_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(http_request_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(http_request_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(http_response_fuzzer_test_one_entry - test/core/http/response_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(http_response_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(http_response_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(json_fuzzer_test_one_entry - test/core/json/fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(json_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(json_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(nanopb_fuzzer_response_test_one_entry - test/core/nanopb/fuzzer_response.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(nanopb_fuzzer_response_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(nanopb_fuzzer_response_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(nanopb_fuzzer_serverlist_test_one_entry - test/core/nanopb/fuzzer_serverlist.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(nanopb_fuzzer_serverlist_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(nanopb_fuzzer_serverlist_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_decode_fuzzer_one_entry - test/core/slice/percent_decode_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(percent_decode_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_decode_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(percent_encode_fuzzer_one_entry - test/core/slice/percent_encode_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(percent_encode_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(percent_encode_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(server_fuzzer_one_entry - test/core/end2end/fuzzers/server_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(server_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(server_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(ssl_server_fuzzer_one_entry - test/core/security/ssl_server_fuzzer.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(ssl_server_fuzzer_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(ssl_server_fuzzer_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) -if (gRPC_BUILD_TESTS) - -add_executable(uri_fuzzer_test_one_entry - test/core/client_channel/uri_fuzzer_test.c - test/core/util/one_corpus_entry_fuzzer.c -) - - -target_include_directories(uri_fuzzer_test_one_entry - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR} - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include - PRIVATE ${BORINGSSL_ROOT_DIR}/include - PRIVATE ${PROTOBUF_ROOT_DIR}/src - PRIVATE ${BENCHMARK_ROOT_DIR}/include - PRIVATE ${ZLIB_ROOT_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/zlib - PRIVATE ${CARES_BUILD_INCLUDE_DIR} - PRIVATE ${CARES_INCLUDE_DIR} - PRIVATE ${CARES_PLATFORM_INCLUDE_DIR} - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/cares/cares - PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/third_party/gflags/include -) - -target_link_libraries(uri_fuzzer_test_one_entry - ${_gRPC_ALLTARGETS_LIBRARIES} - grpc_test_util - grpc - gpr_test_util - gpr -) - -endif (gRPC_BUILD_TESTS) - - - - - - - -if (gRPC_INSTALL) - install(EXPORT gRPCTargets - DESTINATION ${gRPC_INSTALL_CMAKEDIR} - NAMESPACE gRPC:: - ) -endif() - -foreach(_config gRPCConfig gRPCConfigVersion) - configure_file(tools/cmake/${_config}.cmake.in - ${_config}.cmake @ONLY) - install(FILES ${CMAKE_CURRENT_BINARY_DIR}/${_config}.cmake - DESTINATION ${gRPC_INSTALL_CMAKEDIR} - ) -endforeach() diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index fbd89bad079c5d7f6c2909ca643f4c175428e77f..aaae18a313dd082b428654091c9411600c981ec9 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -61,9 +61,15 @@ if ("${NSYNC_LANGUAGE}X" STREQUAL "c++11X") ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX") include_directories ("${PROJECT_SOURCE_DIR}/platform/macos") + include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") + # Some versions of MacOS, such as Sierra, require _DARWIN_C_SOURCE + # when including certin C++ standard header files, such as . + add_definitions ("-D_DARWIN_C_SOURCE") add_compile_options ("-std=c++11") set (NSYNC_OS_SRC ${NSYNC_OS_CPP_SRC} + "platform/posix/src/clock_gettime.c" + "platform/posix/src/nsync_semaphore_mutex.c" ) set (NSYNC_TEST_OS_SRC "platform/posix/src/start_thread.c" @@ -138,6 +144,10 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "DarwinX") include_directories ("${PROJECT_SOURCE_DIR}/platform/macos") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/clock_gettime.c" + "platform/posix/src/nsync_semaphore_mutex.c" + ) include_directories ("${PROJECT_SOURCE_DIR}/platform/posix") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "LinuxX") include_directories ("${PROJECT_SOURCE_DIR}/platform/linux") @@ -148,12 +158,21 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) endif () endif () diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0fca690ef6bedc5a872498583dfd0cbb55e2143 --- /dev/null +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -0,0 +1,449 @@ +tensorflow +tensorflow/core +tensorflow/core/example +tensorflow/core/framework +tensorflow/core/lib +tensorflow/core/lib/core +tensorflow/core/protobuf +tensorflow/core/util +tensorflow/examples +tensorflow/examples/tutorials +tensorflow/examples/tutorials/mnist +tensorflow/python +tensorflow/python/client +tensorflow/python/data +tensorflow/python/data/ops +tensorflow/python/data/util +tensorflow/python/debug +tensorflow/python/debug/cli +tensorflow/python/debug/examples +tensorflow/python/debug/lib +tensorflow/python/debug/wrappers +tensorflow/python/eager +tensorflow/python/estimator +tensorflow/python/estimator/canned +tensorflow/python/estimator/export +tensorflow/python/estimator/inputs +tensorflow/python/estimator/inputs/queues +tensorflow/python/feature_column +tensorflow/python/framework +tensorflow/python/grappler +tensorflow/python/keras +tensorflow/python/keras/activations +tensorflow/python/keras/applications +tensorflow/python/keras/applications/inception_resnet_v2 +tensorflow/python/keras/applications/inception_v3 +tensorflow/python/keras/applications/mobilenet +tensorflow/python/keras/applications/resnet50 +tensorflow/python/keras/applications/vgg16 +tensorflow/python/keras/applications/vgg19 +tensorflow/python/keras/applications/xception +tensorflow/python/keras/backend +tensorflow/python/keras/callbacks +tensorflow/python/keras/constraints +tensorflow/python/keras/datasets +tensorflow/python/keras/datasets/boston_housing +tensorflow/python/keras/datasets/cifar10 +tensorflow/python/keras/datasets/cifar100 +tensorflow/python/keras/datasets/fashion_mnist +tensorflow/python/keras/datasets/imdb +tensorflow/python/keras/datasets/mnist +tensorflow/python/keras/datasets/reuters +tensorflow/python/keras/estimator +tensorflow/python/keras/initializers +tensorflow/python/keras/layers +tensorflow/python/keras/losses +tensorflow/python/keras/metrics +tensorflow/python/keras/models +tensorflow/python/keras/optimizers +tensorflow/python/keras/preprocessing +tensorflow/python/keras/preprocessing/image +tensorflow/python/keras/preprocessing/sequence +tensorflow/python/keras/preprocessing/text +tensorflow/python/keras/regularizers +tensorflow/python/keras/utils +tensorflow/python/keras/wrappers +tensorflow/python/keras/wrappers/scikit_learn +tensorflow/python/keras/_impl +tensorflow/python/keras/_impl/keras +tensorflow/python/keras/_impl/keras/applications +tensorflow/python/keras/_impl/keras/datasets +tensorflow/python/keras/_impl/keras/engine +tensorflow/python/keras/_impl/keras/layers +tensorflow/python/keras/_impl/keras/preprocessing +tensorflow/python/keras/_impl/keras/utils +tensorflow/python/keras/_impl/keras/wrappers +tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/distributions +tensorflow/python/kernel_tests/linalg +tensorflow/python/kernel_tests/random +tensorflow/python/layers +tensorflow/python/lib +tensorflow/python/lib/core +tensorflow/python/lib/io +tensorflow/python/ops +tensorflow/python/ops/distributions +tensorflow/python/ops/linalg +tensorflow/python/ops/losses +tensorflow/python/platform +tensorflow/python/platform/default +tensorflow/python/platform/summary +tensorflow/python/profiler/ +tensorflow/python/profiler/internal +tensorflow/python/saved_model +tensorflow/python/summary +tensorflow/python/summary/writer +tensorflow/python/tools +tensorflow/python/training +tensorflow/python/user_ops +tensorflow/python/util +tensorflow/python/util/protobuf +tensorflow/tools +tensorflow/tools/graph_transforms +tensorflow/contrib +tensorflow/contrib/all_reduce +tensorflow/contrib/all_reduce/python +tensorflow/contrib/android +tensorflow/contrib/android/java +tensorflow/contrib/android/java/org +tensorflow/contrib/android/java/org/tensorflow +tensorflow/contrib/android/java/org/tensorflow/contrib +tensorflow/contrib/android/java/org/tensorflow/contrib/android +tensorflow/contrib/android/jni +tensorflow/contrib/batching +tensorflow/contrib/batching/kernels +tensorflow/contrib/batching/python +tensorflow/contrib/batching/python/ops +tensorflow/contrib/bayesflow +tensorflow/contrib/bayesflow/examples +tensorflow/contrib/bayesflow/examples/reinforce_simple +tensorflow/contrib/bayesflow/python +tensorflow/contrib/bayesflow/python/ops +tensorflow/contrib/boosted_trees +tensorflow/contrib/boosted_trees/estimator_batch +tensorflow/contrib/boosted_trees/kernels +tensorflow/contrib/boosted_trees/ops +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/boosted_trees/python +tensorflow/contrib/boosted_trees/python/ops +tensorflow/contrib/cloud +tensorflow/contrib/cloud/kernels +tensorflow/contrib/cloud/ops +tensorflow/contrib/cloud/python +tensorflow/contrib/cloud/python/ops +tensorflow/contrib/cluster_resolver +tensorflow/contrib/cluster_resolver/python +tensorflow/contrib/cluster_resolver/python/training +tensorflow/contrib/compiler +tensorflow/contrib/copy_graph +tensorflow/contrib/copy_graph/python +tensorflow/contrib/copy_graph/python/util +tensorflow/contrib/crf +tensorflow/contrib/crf/python +tensorflow/contrib/crf/python/ops +tensorflow/contrib/cudnn_rnn +tensorflow/contrib/cudnn_rnn/kernels +tensorflow/contrib/cudnn_rnn/ops +tensorflow/contrib/cudnn_rnn/python +tensorflow/contrib/cudnn_rnn/python/layers +tensorflow/contrib/cudnn_rnn/python/ops +tensorflow/contrib/data +tensorflow/contrib/data/kernels +tensorflow/contrib/data/python +tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/ops +tensorflow/contrib/decision_trees +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/deprecated +tensorflow/contrib/distributions +tensorflow/contrib/distributions/python +tensorflow/contrib/distributions/python/ops +tensorflow/contrib/distributions/python/ops/bijectors +tensorflow/contrib/eager +tensorflow/contrib/eager/python +tensorflow/contrib/estimator +tensorflow/contrib/estimator/python +tensorflow/contrib/estimator/python/estimator +tensorflow/contrib/factorization +tensorflow/contrib/factorization/examples +tensorflow/contrib/factorization/kernels +tensorflow/contrib/factorization/ops +tensorflow/contrib/factorization/python +tensorflow/contrib/factorization/python/ops +tensorflow/contrib/ffmpeg +tensorflow/contrib/ffmpeg/default +tensorflow/contrib/framework +tensorflow/contrib/framework/kernels +tensorflow/contrib/framework/ops +tensorflow/contrib/framework/python +tensorflow/contrib/framework/python/framework +tensorflow/contrib/framework/python/ops +tensorflow/contrib/fused_conv +tensorflow/contrib/fused_conv/kernels +tensorflow/contrib/fused_conv/python +tensorflow/contrib/fused_conv/python/ops +tensorflow/contrib/gan +tensorflow/contrib/gan/python +tensorflow/contrib/gan/python/estimator +tensorflow/contrib/gan/python/estimator/python +tensorflow/contrib/gan/python/eval +tensorflow/contrib/gan/python/eval/python +tensorflow/contrib/gan/python/features +tensorflow/contrib/gan/python/features/python +tensorflow/contrib/gan/python/losses +tensorflow/contrib/gan/python/losses/python +tensorflow/contrib/graph_editor +tensorflow/contrib/graph_editor/examples +tensorflow/contrib/grid_rnn +tensorflow/contrib/grid_rnn/python +tensorflow/contrib/grid_rnn/python/ops +tensorflow/contrib/hooks +tensorflow/contrib/hooks/python +tensorflow/contrib/image +tensorflow/contrib/image/kernels +tensorflow/contrib/image/ops +tensorflow/contrib/image/python +tensorflow/contrib/image/python/ops +tensorflow/contrib/input_pipeline +tensorflow/contrib/input_pipeline/kernels +tensorflow/contrib/input_pipeline/ops +tensorflow/contrib/input_pipeline/python +tensorflow/contrib/input_pipeline/python/ops +tensorflow/contrib/integrate +tensorflow/contrib/integrate/python +tensorflow/contrib/integrate/python/ops +tensorflow/contrib/ios_examples +tensorflow/contrib/ios_examples/benchmark +tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj +tensorflow/contrib/ios_examples/benchmark/data +tensorflow/contrib/ios_examples/camera +tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj +tensorflow/contrib/ios_examples/camera/en.lproj +tensorflow/contrib/ios_examples/simple +tensorflow/contrib/ios_examples/simple/data +tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj +tensorflow/contrib/keras +tensorflow/contrib/keras/api +tensorflow/contrib/keras/api/keras +tensorflow/contrib/keras/api/keras/activations +tensorflow/contrib/keras/api/keras/applications +tensorflow/contrib/keras/api/keras/applications/inception_v3 +tensorflow/contrib/keras/api/keras/applications/mobilenet +tensorflow/contrib/keras/api/keras/applications/resnet50 +tensorflow/contrib/keras/api/keras/applications/vgg16 +tensorflow/contrib/keras/api/keras/applications/vgg19 +tensorflow/contrib/keras/api/keras/applications/xception +tensorflow/contrib/keras/api/keras/backend +tensorflow/contrib/keras/api/keras/callbacks +tensorflow/contrib/keras/api/keras/constraints +tensorflow/contrib/keras/api/keras/datasets +tensorflow/contrib/keras/api/keras/datasets/boston_housing +tensorflow/contrib/keras/api/keras/datasets/cifar10 +tensorflow/contrib/keras/api/keras/datasets/cifar100 +tensorflow/contrib/keras/api/keras/datasets/imdb +tensorflow/contrib/keras/api/keras/datasets/mnist +tensorflow/contrib/keras/api/keras/datasets/reuters +tensorflow/contrib/keras/api/keras/initializers +tensorflow/contrib/keras/api/keras/layers +tensorflow/contrib/keras/api/keras/losses +tensorflow/contrib/keras/api/keras/metrics +tensorflow/contrib/keras/api/keras/models +tensorflow/contrib/keras/api/keras/optimizers +tensorflow/contrib/keras/api/keras/preprocessing +tensorflow/contrib/keras/api/keras/preprocessing/image +tensorflow/contrib/keras/api/keras/preprocessing/sequence +tensorflow/contrib/keras/api/keras/preprocessing/text +tensorflow/contrib/keras/api/keras/regularizers +tensorflow/contrib/keras/api/keras/utils +tensorflow/contrib/keras/api/keras/wrappers +tensorflow/contrib/keras/api/keras/wrappers/scikit_learn +tensorflow/contrib/kernel_methods +tensorflow/contrib/kernel_methods/python +tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kfac +tensorflow/contrib/kfac/examples +tensorflow/contrib/kfac/python +tensorflow/contrib/kfac/python/ops +tensorflow/contrib/labeled_tensor +tensorflow/contrib/labeled_tensor/python +tensorflow/contrib/labeled_tensor/python/ops +tensorflow/contrib/layers +tensorflow/contrib/layers/kernels +tensorflow/contrib/layers/ops +tensorflow/contrib/layers/python +tensorflow/contrib/layers/python/layers +tensorflow/contrib/layers/python/ops +tensorflow/contrib/learn +tensorflow/contrib/learn/python +tensorflow/contrib/learn/python/learn +tensorflow/contrib/learn/python/learn/dataframe +tensorflow/contrib/learn/python/learn/dataframe/queues +tensorflow/contrib/learn/python/learn/dataframe/transforms +tensorflow/contrib/learn/python/learn/datasets +tensorflow/contrib/learn/python/learn/datasets/data +tensorflow/contrib/learn/python/learn/estimators +tensorflow/contrib/learn/python/learn/learn_io +tensorflow/contrib/learn/python/learn/ops +tensorflow/contrib/learn/python/learn/preprocessing +tensorflow/contrib/learn/python/learn/utils +tensorflow/contrib/legacy_seq2seq +tensorflow/contrib/legacy_seq2seq/python +tensorflow/contrib/legacy_seq2seq/python/ops +tensorflow/contrib/linalg +tensorflow/contrib/linalg/python +tensorflow/contrib/linalg/python/ops +tensorflow/contrib/linear_optimizer +tensorflow/contrib/linear_optimizer/kernels +tensorflow/contrib/linear_optimizer/kernels/g3doc +tensorflow/contrib/linear_optimizer/python +tensorflow/contrib/linear_optimizer/python/ops +tensorflow/contrib/lookup +tensorflow/contrib/losses +tensorflow/contrib/losses/python +tensorflow/contrib/losses/python/losses +tensorflow/contrib/losses/python/metric_learning +tensorflow/contrib/makefile +tensorflow/contrib/memory_stats +tensorflow/contrib/memory_stats/kernels +tensorflow/contrib/memory_stats/ops +tensorflow/contrib/memory_stats/python +tensorflow/contrib/memory_stats/python/ops +tensorflow/contrib/meta_graph_transform +tensorflow/contrib/metrics +tensorflow/contrib/metrics/ops +tensorflow/contrib/metrics/python +tensorflow/contrib/metrics/python/metrics +tensorflow/contrib/metrics/python/ops +tensorflow/contrib/model_pruning +tensorflow/contrib/model_pruning/examples +tensorflow/contrib/model_pruning/examples/cifar10 +tensorflow/contrib/model_pruning/python +tensorflow/contrib/model_pruning/python/layers +tensorflow/contrib/nccl +tensorflow/contrib/nccl/kernels +tensorflow/contrib/nccl/ops +tensorflow/contrib/nccl/python +tensorflow/contrib/nccl/python/ops +tensorflow/contrib/ndlstm +tensorflow/contrib/ndlstm/python +tensorflow/contrib/nearest_neighbor/kernels +tensorflow/contrib/nearest_neighbor/ops +tensorflow/contrib/nearest_neighbor/python +tensorflow/contrib/nearest_neighbor/python/ops +tensorflow/contrib/nn +tensorflow/contrib/nn/python +tensorflow/contrib/nn/python/ops +tensorflow/contrib/opt +tensorflow/contrib/opt/python +tensorflow/contrib/opt/python/training +tensorflow/contrib/pi_examples +tensorflow/contrib/pi_examples/camera +tensorflow/contrib/pi_examples/label_image +tensorflow/contrib/pi_examples/label_image/data +tensorflow/contrib/periodic_resample +tensorflow/contrib/periodic_resample/python +tensorflow/contrib/periodic_resample/python/kernels +tensorflow/contrib/periodic_resample/python/ops +tensorflow/contrib/predictor +tensorflow/contrib/quantization +tensorflow/contrib/quantization/python +tensorflow/contrib/quantize +tensorflow/contrib/quantize/python +tensorflow/contrib/receptive_field +tensorflow/contrib/receptive_field/python +tensorflow/contrib/reduce_slice_ops +tensorflow/contrib/reduce_slice_ops/kernels +tensorflow/contrib/reduce_slice_ops/ops +tensorflow/contrib/reduce_slice_ops/python +tensorflow/contrib/reduce_slice_ops/python/ops +tensorflow/contrib/remote_fused_graph/pylib +tensorflow/contrib/remote_fused_graph/pylib/python +tensorflow/contrib/remote_fused_graph/pylib/python/ops +tensorflow/contrib/resampler +tensorflow/contrib/resampler/kernels +tensorflow/contrib/resampler/ops +tensorflow/contrib/resampler/python +tensorflow/contrib/resampler/python/ops +tensorflow/contrib/rnn +tensorflow/contrib/rnn/kernels +tensorflow/contrib/rnn/ops +tensorflow/contrib/rnn/python +tensorflow/contrib/rnn/python/kernel_tests +tensorflow/contrib/rnn/python/ops +tensorflow/contrib/saved_model +tensorflow/contrib/saved_model/python +tensorflow/contrib/saved_model/python/saved_model +tensorflow/contrib/seq2seq +tensorflow/contrib/seq2seq/kernels +tensorflow/contrib/seq2seq/ops +tensorflow/contrib/seq2seq/python +tensorflow/contrib/seq2seq/python/ops +tensorflow/contrib/session_bundle +tensorflow/contrib/session_bundle/example +tensorflow/contrib/signal +tensorflow/contrib/signal/python +tensorflow/contrib/signal/python/ops +tensorflow/contrib/slim +tensorflow/contrib/slim/python +tensorflow/contrib/slim/python/slim +tensorflow/contrib/slim/python/slim/data +tensorflow/contrib/slim/python/slim/nets +tensorflow/contrib/solvers +tensorflow/contrib/solvers/python +tensorflow/contrib/solvers/python/ops +tensorflow/contrib/sparsemax +tensorflow/contrib/sparsemax/python +tensorflow/contrib/sparsemax/python/ops +tensorflow/contrib/specs +tensorflow/contrib/specs/python +tensorflow/contrib/staging +tensorflow/contrib/stat_summarizer +tensorflow/contrib/stat_summarizer/python +tensorflow/contrib/stateless +tensorflow/contrib/stateless/python +tensorflow/contrib/summary +tensorflow/contrib/tensorboard +tensorflow/contrib/tensorboard/plugins +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensor_forest +tensorflow/contrib/tensor_forest/client +tensorflow/contrib/tensor_forest/core +tensorflow/contrib/tensor_forest/core/ops +tensorflow/contrib/tensor_forest/data +tensorflow/contrib/tensor_forest/hybrid +tensorflow/contrib/tensor_forest/hybrid/core +tensorflow/contrib/tensor_forest/hybrid/core/ops +tensorflow/contrib/tensor_forest/hybrid/ops +tensorflow/contrib/tensor_forest/hybrid/python +tensorflow/contrib/tensor_forest/hybrid/python/layers +tensorflow/contrib/tensor_forest/hybrid/python/models +tensorflow/contrib/tensor_forest/hybrid/python/ops +tensorflow/contrib/tensor_forest/kernels +tensorflow/contrib/tensor_forest/python +tensorflow/contrib/tensor_forest/python/ops +tensorflow/contrib/testing +tensorflow/contrib/testing/python +tensorflow/contrib/testing/python/framework +tensorflow/contrib/text +tensorflow/contrib/text/kernels +tensorflow/contrib/text/ops +tensorflow/contrib/text/python +tensorflow/contrib/text/python/ops +tensorflow/contrib/tfprof +tensorflow/contrib/timeseries +tensorflow/contrib/timeseries/examples +tensorflow/contrib/timeseries/examples/data +tensorflow/contrib/timeseries/python +tensorflow/contrib/timeseries/python/timeseries +tensorflow/contrib/timeseries/python/timeseries/state_space_models +tensorflow/contrib/tpu +tensorflow/contrib/tpu/ops +tensorflow/contrib/tpu/profiler +tensorflow/contrib/tpu/python +tensorflow/contrib/tpu/python/ops +tensorflow/contrib/tpu/python/profiler +tensorflow/contrib/tpu/python/tpu +tensorflow/contrib/training +tensorflow/contrib/training/python +tensorflow/contrib/training/python/training +tensorflow/contrib/util diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a9c406d8b118c10ddcaafb0e4fc242aa79cdb57 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -0,0 +1,19 @@ +tensorflow/core +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/cloud/kernels +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/gdr +tensorflow/contrib/lite/toco +tensorflow/contrib/mpi +tensorflow/contrib/mpi_collectives +tensorflow/contrib/session_bundle +tensorflow/contrib/tensor_forest/proto +tensorflow/contrib/tensorboard/graph_explorer/proto +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensorboard/plugins/trace +tensorflow/contrib/tpu/proto +tensorflow/contrib/tpu/profiler +tensorflow/contrib/training/python/training +tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/python_protos_cc.txt b/tensorflow/contrib/cmake/python_protos_cc.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4a257b25c814a1464308d0e6ce3ce65d21f6a36 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos_cc.txt @@ -0,0 +1,5 @@ +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/session_bundle +tensorflow/contrib/tensorboard +tensorflow/contrib/training diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index f3882e8cf76c6dad31371fc340de959c05411a2f..c6a15f2ca075c8de96786a580c7ddb89541df5bc 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,7 +21,6 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" - "${tensorflow_source_dir}/tensorflow/c/eager/tape.cc" "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" @@ -47,4 +46,5 @@ add_dependencies( tf_c_python_api tf_c tf_core_lib + tf_core_framework tf_protos_cc) diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index a5f5ae5478f3ca82f428d494f2822d0c69064b98..6e2ac203f9a7f96cb14752a91483840a9eb6b451 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -83,7 +83,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names}) ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.cc - COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} + COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir ) @@ -148,7 +148,11 @@ list(REMOVE_ITEM tf_cc_srcs ${tf_cc_test_srcs}) add_library(tf_cc OBJECT ${tf_cc_srcs}) add_dependencies(tf_cc tf_cc_framework tf_cc_ops) -set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") +if (WIN32) + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") +else (WIN32) + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so") +endif (WIN32) add_custom_target(tf_extension_ops) function(AddUserOps) @@ -164,15 +168,13 @@ function(AddUserOps) # create shared library from source and cuda obj add_library(${_AT_TARGET} SHARED ${_AT_SOURCES} ${gpu_lib}) target_link_libraries(${_AT_TARGET} ${pywrap_tensorflow_lib}) - if(WIN32) - if (tensorflow_ENABLE_GPU AND _AT_GPUSOURCES) - # some ops call out to cuda directly; need to link libs for the cuda dlls - target_link_libraries(${_AT_TARGET} ${CUDA_LIBRARIES}) - endif() - if (_AT_DISTCOPY) - add_custom_command(TARGET ${_AT_TARGET} POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy $ ${_AT_DISTCOPY}/) - endif() + if (tensorflow_ENABLE_GPU AND _AT_GPUSOURCES) + # some ops call out to cuda directly; need to link libs for the cuda dlls + target_link_libraries(${_AT_TARGET} ${CUDA_LIBRARIES}) + endif() + if (_AT_DISTCOPY) + add_custom_command(TARGET ${_AT_TARGET} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy $ ${_AT_DISTCOPY}/) endif() if (_AT_DEPENDS) add_dependencies(${_AT_TARGET} ${_AT_DEPENDS}) @@ -180,9 +182,19 @@ function(AddUserOps) # make sure TF_COMPILE_LIBRARY is not defined for this target get_target_property(target_compile_flags ${_AT_TARGET} COMPILE_FLAGS) if(target_compile_flags STREQUAL "target_compile_flags-NOTFOUND") - set(target_compile_flags "/UTF_COMPILE_LIBRARY") + if (WIN32) + set(target_compile_flags "/UTF_COMPILE_LIBRARY") + else (WIN32) + # gcc uses UTF as default + set(target_compile_flags "-finput-charset=UTF-8") + endif (WIN32) else() - set(target_compile_flags "${target_compile_flags} /UTF_COMPILE_LIBRARY") + if (WIN32) + set(target_compile_flags "${target_compile_flags} /UTF_COMPILE_LIBRARY") + else (WIN32) + # gcc uses UTF as default + set(target_compile_flags "${target_compile_flags} -finput-charset=UTF-8") + endif (WIN32) endif() set_target_properties(${_AT_TARGET} PROPERTIES COMPILE_FLAGS ${target_compile_flags}) add_dependencies(tf_extension_ops ${_AT_TARGET}) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 5c01ca382fb9cc7a01a6f2b60a510c59f0aa7119..e4213ea2a47da2a7381cccd0504235ad62018d4e 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,7 +63,7 @@ if (tensorflow_ENABLE_GPU) file(GLOB_RECURSE tf_core_gpu_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu/cupti_wrapper.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc" + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.h" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.cc" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index c3dc8531bb9f0164f06841d9715f227202fdb7c9..5ec1a8d04fa41c6b36400fc0998af77592866150 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -211,7 +211,7 @@ if (NOT tensorflow_ENABLE_GPU) list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) else() file(GLOB tf_core_platform_srcs_exclude - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc") + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() @@ -301,6 +301,8 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/session.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_factory.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/session_options.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*.h" "${tensorflow_source_dir}/public/*.h" ) @@ -314,6 +316,7 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.h" "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" ) list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index f978c8ccd5a454ca4a89de0ab5d757b566295c60..eb6bf567aa7dc2e87f3d5ce462a7680fc9850bbf 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -55,10 +55,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/bias-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/categorical-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/dense-quantized-feature-column-handler.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/model_ops.cc" @@ -154,9 +150,6 @@ list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) if(WIN32) file(GLOB_RECURSE tf_core_kernels_windows_exclude_srcs # not working on windows yet - "${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/neon/*" # not in core - those are loaded dynamically as dll "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc" @@ -183,6 +176,7 @@ file(GLOB_RECURSE tf_core_gpu_kernels_srcs "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc" + "${tensorflow_source_dir}/tensorflow/contrib/resampler/kernels/*.cu.cc" ) if(WIN32 AND tensorflow_ENABLE_GPU) @@ -206,16 +200,16 @@ endif(WIN32 AND tensorflow_ENABLE_GPU) add_library(tf_core_kernels OBJECT ${tf_core_kernels_srcs}) add_dependencies(tf_core_kernels tf_core_cpu) -if(WIN32) +if (WIN32) target_compile_options(tf_core_kernels PRIVATE /MP) - if (tensorflow_ENABLE_GPU) - set_source_files_properties(${tf_core_gpu_kernels_srcs} PROPERTIES CUDA_SOURCE_PROPERTY_FORMAT OBJ) - set(tf_core_gpu_kernels_lib tf_core_gpu_kernels) - cuda_add_library(${tf_core_gpu_kernels_lib} ${tf_core_gpu_kernels_srcs}) - set_target_properties(${tf_core_gpu_kernels_lib} - PROPERTIES DEBUG_POSTFIX "" - COMPILE_FLAGS "${TF_REGULAR_CXX_FLAGS}" - ) - add_dependencies(${tf_core_gpu_kernels_lib} tf_core_cpu) - endif() +endif (WIN32) +if (tensorflow_ENABLE_GPU) + set_source_files_properties(${tf_core_gpu_kernels_srcs} PROPERTIES CUDA_SOURCE_PROPERTY_FORMAT OBJ) + set(tf_core_gpu_kernels_lib tf_core_gpu_kernels) + cuda_add_library(${tf_core_gpu_kernels_lib} ${tf_core_gpu_kernels_srcs}) + set_target_properties(${tf_core_gpu_kernels_lib} + PROPERTIES DEBUG_POSTFIX "" + COMPILE_FLAGS "${TF_REGULAR_CXX_FLAGS}" + ) + add_dependencies(${tf_core_gpu_kernels_lib} tf_core_cpu) endif() diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 4a61ed7a3548b1992ddc71acb8a7761e252296ea..e8c2cd347327843d10d13c1d24a800ff776aa8c1 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -92,6 +92,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/con GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(periodic_resample "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nearest_neighbor "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(resampler "${tensorflow_source_dir}/tensorflow/contrib/resampler/ops/resampler_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_grappler.cmake b/tensorflow/contrib/cmake/tf_grappler.cmake index a7841c98e83ec8c3eb91edfd9d639e169cb5f440..410490531a300c091afdd857d7f2d4e789a4c80e 100644 --- a/tensorflow/contrib/cmake/tf_grappler.cmake +++ b/tensorflow/contrib/cmake/tf_grappler.cmake @@ -23,7 +23,7 @@ file(GLOB tf_grappler_srcs "${tensorflow_source_dir}/tensorflow/python/grappler/model_analyzer.cc" "${tensorflow_source_dir}/tensorflow/python/grappler/model_analyzer.h" ) - + add_library(tf_grappler OBJECT ${tf_grappler_srcs}) add_dependencies(tf_grappler tf_core_cpu) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_label_image_example.cmake b/tensorflow/contrib/cmake/tf_label_image_example.cmake index 0d3a4699ebb102257e8a4a816652c90ffff42d92..7f2f60b0897f62d335416f4fcffd91c1e629cf28 100644 --- a/tensorflow/contrib/cmake/tf_label_image_example.cmake +++ b/tensorflow/contrib/cmake/tf_label_image_example.cmake @@ -34,3 +34,8 @@ target_link_libraries(tf_label_image_example PUBLIC ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} ) + +install(TARGETS tf_label_image_example + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 277818b159062da4ba6efaacbe006da623c8619c..8db6929e31a1a5f5c793721f455a664bd6741b06 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -120,32 +120,34 @@ function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/*.proto" - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/decision_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos.txt python_protos) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}") +STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}") + +foreach(python_proto ${python_protos}) + file(GLOB_RECURSE tf_python_protos_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto}/*.proto" + ) + list(APPEND tf_python_protos_srcs ${tf_python_protos_src}) +endforeach(python_proto) + RELATIVE_PROTOBUF_GENERATE_PYTHON( - ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs} + ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_python_protos_srcs} ) -# NOTE(mrry): Avoid regenerating the tensorflow/core protos because this -# can cause benign-but-failing-on-Windows-due-to-file-locking conflicts -# when two rules attempt to generate the same file. -file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos_cc.txt python_protos_cc) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}") +STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}") + +foreach(python_proto_cc ${python_protos_cc}) + file(GLOB_RECURSE tf_python_protos_cc_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto_cc}/*.proto" + ) + list(APPEND tf_python_protos_cc_srcs ${tf_python_protos_cc_src}) +endforeach(python_proto_cc) + RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS ${tensorflow_source_dir} ${tf_python_protos_cc_srcs} ) @@ -191,458 +193,28 @@ function(add_python_module MODULE_NAME) endif() endfunction() -add_python_module("tensorflow") -add_python_module("tensorflow/core") -add_python_module("tensorflow/core/example") -add_python_module("tensorflow/core/framework") -add_python_module("tensorflow/core/lib") -add_python_module("tensorflow/core/lib/core") -add_python_module("tensorflow/core/protobuf") -add_python_module("tensorflow/core/util") -add_python_module("tensorflow/examples") -add_python_module("tensorflow/examples/tutorials") -add_python_module("tensorflow/examples/tutorials/mnist") -add_python_module("tensorflow/python") -add_python_module("tensorflow/python/client") -add_python_module("tensorflow/python/data") -add_python_module("tensorflow/python/data/ops") -add_python_module("tensorflow/python/data/util") -add_python_module("tensorflow/python/debug") -add_python_module("tensorflow/python/debug/cli") -add_python_module("tensorflow/python/debug/examples") -add_python_module("tensorflow/python/debug/lib") -add_python_module("tensorflow/python/debug/wrappers") -add_python_module("tensorflow/python/eager") -add_python_module("tensorflow/python/estimator") -add_python_module("tensorflow/python/estimator/canned") -add_python_module("tensorflow/python/estimator/export") -add_python_module("tensorflow/python/estimator/inputs") -add_python_module("tensorflow/python/estimator/inputs/queues") -add_python_module("tensorflow/python/feature_column") -add_python_module("tensorflow/python/framework") -add_python_module("tensorflow/python/grappler") -add_python_module("tensorflow/python/keras") -add_python_module("tensorflow/python/keras/activations") -add_python_module("tensorflow/python/keras/applications") -add_python_module("tensorflow/python/keras/applications/inception_v3") -add_python_module("tensorflow/python/keras/applications/mobilenet") -add_python_module("tensorflow/python/keras/applications/resnet50") -add_python_module("tensorflow/python/keras/applications/vgg16") -add_python_module("tensorflow/python/keras/applications/vgg19") -add_python_module("tensorflow/python/keras/applications/xception") -add_python_module("tensorflow/python/keras/backend") -add_python_module("tensorflow/python/keras/callbacks") -add_python_module("tensorflow/python/keras/constraints") -add_python_module("tensorflow/python/keras/datasets") -add_python_module("tensorflow/python/keras/datasets/boston_housing") -add_python_module("tensorflow/python/keras/datasets/cifar10") -add_python_module("tensorflow/python/keras/datasets/cifar100") -add_python_module("tensorflow/python/keras/datasets/imdb") -add_python_module("tensorflow/python/keras/datasets/mnist") -add_python_module("tensorflow/python/keras/datasets/reuters") -add_python_module("tensorflow/python/keras/estimator") -add_python_module("tensorflow/python/keras/initializers") -add_python_module("tensorflow/python/keras/layers") -add_python_module("tensorflow/python/keras/losses") -add_python_module("tensorflow/python/keras/metrics") -add_python_module("tensorflow/python/keras/models") -add_python_module("tensorflow/python/keras/optimizers") -add_python_module("tensorflow/python/keras/preprocessing") -add_python_module("tensorflow/python/keras/preprocessing/image") -add_python_module("tensorflow/python/keras/preprocessing/sequence") -add_python_module("tensorflow/python/keras/preprocessing/text") -add_python_module("tensorflow/python/keras/regularizers") -add_python_module("tensorflow/python/keras/utils") -add_python_module("tensorflow/python/keras/wrappers") -add_python_module("tensorflow/python/keras/wrappers/scikit_learn") -add_python_module("tensorflow/python/keras/_impl") -add_python_module("tensorflow/python/keras/_impl/keras") -add_python_module("tensorflow/python/keras/_impl/keras/applications") -add_python_module("tensorflow/python/keras/_impl/keras/datasets") -add_python_module("tensorflow/python/keras/_impl/keras/engine") -add_python_module("tensorflow/python/keras/_impl/keras/layers") -add_python_module("tensorflow/python/keras/_impl/keras/preprocessing") -add_python_module("tensorflow/python/keras/_impl/keras/utils") -add_python_module("tensorflow/python/keras/_impl/keras/wrappers") -add_python_module("tensorflow/python/kernel_tests") -add_python_module("tensorflow/python/kernel_tests/distributions") -add_python_module("tensorflow/python/kernel_tests/linalg") -add_python_module("tensorflow/python/layers") -add_python_module("tensorflow/python/lib") -add_python_module("tensorflow/python/lib/core") -add_python_module("tensorflow/python/lib/io") -add_python_module("tensorflow/python/ops") -add_python_module("tensorflow/python/ops/distributions") -add_python_module("tensorflow/python/ops/linalg") -add_python_module("tensorflow/python/ops/losses") -add_python_module("tensorflow/python/platform") -add_python_module("tensorflow/python/platform/default") -add_python_module("tensorflow/python/platform/summary") -add_python_module("tensorflow/python/profiler/") -add_python_module("tensorflow/python/profiler/internal") -add_python_module("tensorflow/python/saved_model") -add_python_module("tensorflow/python/summary") -add_python_module("tensorflow/python/summary/writer") -add_python_module("tensorflow/python/tools") -add_python_module("tensorflow/python/training") -add_python_module("tensorflow/python/user_ops") -add_python_module("tensorflow/python/util") -add_python_module("tensorflow/python/util/protobuf") -add_python_module("tensorflow/tools") -add_python_module("tensorflow/tools/graph_transforms") -add_python_module("tensorflow/contrib") -add_python_module("tensorflow/contrib/all_reduce") -add_python_module("tensorflow/contrib/all_reduce/python") -add_python_module("tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/java") -add_python_module("tensorflow/contrib/android/java/org") -add_python_module("tensorflow/contrib/android/java/org/tensorflow") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/jni") -add_python_module("tensorflow/contrib/bayesflow") -add_python_module("tensorflow/contrib/bayesflow/examples") -add_python_module("tensorflow/contrib/bayesflow/examples/reinforce_simple") -add_python_module("tensorflow/contrib/bayesflow/python") -add_python_module("tensorflow/contrib/bayesflow/python/kernel_tests") -add_python_module("tensorflow/contrib/bayesflow/python/ops") -add_python_module("tensorflow/contrib/boosted_trees") -add_python_module("tensorflow/contrib/boosted_trees/estimator_batch") -add_python_module("tensorflow/contrib/boosted_trees/ops") -add_python_module("tensorflow/contrib/boosted_trees/proto") -add_python_module("tensorflow/contrib/boosted_trees/python") -add_python_module("tensorflow/contrib/boosted_trees/python/kernel_tests") -add_python_module("tensorflow/contrib/boosted_trees/python/ops") -add_python_module("tensorflow/contrib/cloud") -add_python_module("tensorflow/contrib/cloud/kernels") -add_python_module("tensorflow/contrib/cloud/ops") -add_python_module("tensorflow/contrib/cloud/python") -add_python_module("tensorflow/contrib/cloud/python/ops") -add_python_module("tensorflow/contrib/cluster_resolver") -add_python_module("tensorflow/contrib/cluster_resolver/python") -add_python_module("tensorflow/contrib/cluster_resolver/python/training") -add_python_module("tensorflow/contrib/compiler") -add_python_module("tensorflow/contrib/copy_graph") -add_python_module("tensorflow/contrib/copy_graph/python") -add_python_module("tensorflow/contrib/copy_graph/python/util") -add_python_module("tensorflow/contrib/crf") -add_python_module("tensorflow/contrib/crf/python") -add_python_module("tensorflow/contrib/crf/python/kernel_tests") -add_python_module("tensorflow/contrib/crf/python/ops") -add_python_module("tensorflow/contrib/cudnn_rnn") -add_python_module("tensorflow/contrib/cudnn_rnn/kernels") -add_python_module("tensorflow/contrib/cudnn_rnn/ops") -add_python_module("tensorflow/contrib/cudnn_rnn/python") -add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/cudnn_rnn/python/ops") -add_python_module("tensorflow/contrib/data") -add_python_module("tensorflow/contrib/data/python") -add_python_module("tensorflow/contrib/data/python/kernel_tests") -add_python_module("tensorflow/contrib/data/python/ops") -add_python_module("tensorflow/contrib/decision_trees") -add_python_module("tensorflow/contrib/decision_trees/proto") -add_python_module("tensorflow/contrib/deprecated") -add_python_module("tensorflow/contrib/distributions") -add_python_module("tensorflow/contrib/distributions/python") -add_python_module("tensorflow/contrib/distributions/python/kernel_tests") -add_python_module("tensorflow/contrib/distributions/python/ops") -add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") -add_python_module("tensorflow/contrib/eager") -add_python_module("tensorflow/contrib/eager/python") -add_python_module("tensorflow/contrib/estimator") -add_python_module("tensorflow/contrib/estimator/python") -add_python_module("tensorflow/contrib/estimator/python/estimator") -add_python_module("tensorflow/contrib/factorization") -add_python_module("tensorflow/contrib/factorization/examples") -add_python_module("tensorflow/contrib/factorization/kernels") -add_python_module("tensorflow/contrib/factorization/ops") -add_python_module("tensorflow/contrib/factorization/python") -add_python_module("tensorflow/contrib/factorization/python/kernel_tests") -add_python_module("tensorflow/contrib/factorization/python/ops") -add_python_module("tensorflow/contrib/ffmpeg") -add_python_module("tensorflow/contrib/ffmpeg/default") -add_python_module("tensorflow/contrib/ffmpeg/testdata") -add_python_module("tensorflow/contrib/framework") -add_python_module("tensorflow/contrib/framework/kernels") -add_python_module("tensorflow/contrib/framework/ops") -add_python_module("tensorflow/contrib/framework/python") -add_python_module("tensorflow/contrib/framework/python/framework") -add_python_module("tensorflow/contrib/framework/python/ops") -add_python_module("tensorflow/contrib/gan") -add_python_module("tensorflow/contrib/gan/python") -add_python_module("tensorflow/contrib/gan/python/eval") -add_python_module("tensorflow/contrib/gan/python/eval/python") -add_python_module("tensorflow/contrib/gan/python/features") -add_python_module("tensorflow/contrib/gan/python/features/python") -add_python_module("tensorflow/contrib/gan/python/estimator") -add_python_module("tensorflow/contrib/gan/python/estimator/python") -add_python_module("tensorflow/contrib/gan/python/losses") -add_python_module("tensorflow/contrib/gan/python/losses/python") -add_python_module("tensorflow/contrib/graph_editor") -add_python_module("tensorflow/contrib/graph_editor/examples") -add_python_module("tensorflow/contrib/graph_editor/tests") -add_python_module("tensorflow/contrib/grid_rnn") -add_python_module("tensorflow/contrib/grid_rnn/python") -add_python_module("tensorflow/contrib/grid_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/grid_rnn/python/ops") -add_python_module("tensorflow/contrib/hooks") -add_python_module("tensorflow/contrib/image") -add_python_module("tensorflow/contrib/image/ops") -add_python_module("tensorflow/contrib/image/python") -add_python_module("tensorflow/contrib/image/python/ops") -add_python_module("tensorflow/contrib/input_pipeline") -add_python_module("tensorflow/contrib/input_pipeline/ops") -add_python_module("tensorflow/contrib/input_pipeline/python") -add_python_module("tensorflow/contrib/input_pipeline/python/ops") -add_python_module("tensorflow/contrib/integrate") -add_python_module("tensorflow/contrib/integrate/python") -add_python_module("tensorflow/contrib/integrate/python/ops") -add_python_module("tensorflow/contrib/ios_examples") -add_python_module("tensorflow/contrib/ios_examples/benchmark") -add_python_module("tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/benchmark/data") -add_python_module("tensorflow/contrib/ios_examples/camera") -add_python_module("tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/camera/en.lproj") -add_python_module("tensorflow/contrib/ios_examples/simple") -add_python_module("tensorflow/contrib/ios_examples/simple/data") -add_python_module("tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj") -add_python_module("tensorflow/contrib/keras") -add_python_module("tensorflow/contrib/keras/api") -add_python_module("tensorflow/contrib/keras/api/keras") -add_python_module("tensorflow/contrib/keras/api/keras/activations") -add_python_module("tensorflow/contrib/keras/api/keras/applications") -add_python_module("tensorflow/contrib/keras/api/keras/applications/inception_v3") -add_python_module("tensorflow/contrib/keras/api/keras/applications/mobilenet") -add_python_module("tensorflow/contrib/keras/api/keras/applications/resnet50") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg16") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg19") -add_python_module("tensorflow/contrib/keras/api/keras/applications/xception") -add_python_module("tensorflow/contrib/keras/api/keras/backend") -add_python_module("tensorflow/contrib/keras/api/keras/callbacks") -add_python_module("tensorflow/contrib/keras/api/keras/constraints") -add_python_module("tensorflow/contrib/keras/api/keras/datasets") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/boston_housing") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar10") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar100") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/imdb") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/mnist") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/reuters") -add_python_module("tensorflow/contrib/keras/api/keras/initializers") -add_python_module("tensorflow/contrib/keras/api/keras/layers") -add_python_module("tensorflow/contrib/keras/api/keras/losses") -add_python_module("tensorflow/contrib/keras/api/keras/metrics") -add_python_module("tensorflow/contrib/keras/api/keras/models") -add_python_module("tensorflow/contrib/keras/api/keras/optimizers") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/image") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/sequence") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/text") -add_python_module("tensorflow/contrib/keras/api/keras/regularizers") -add_python_module("tensorflow/contrib/keras/api/keras/utils") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers/scikit_learn") -add_python_module("tensorflow/contrib/keras/python") -add_python_module("tensorflow/contrib/keras/python/keras") -add_python_module("tensorflow/contrib/keras/python/keras/applications") -add_python_module("tensorflow/contrib/keras/python/keras/datasets") -add_python_module("tensorflow/contrib/keras/python/keras/engine") -add_python_module("tensorflow/contrib/keras/python/keras/layers") -add_python_module("tensorflow/contrib/keras/python/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/python/keras/utils") -add_python_module("tensorflow/contrib/keras/python/keras/wrappers") -add_python_module("tensorflow/contrib/kernel_methods") -add_python_module("tensorflow/contrib/kernel_methods/python") -add_python_module("tensorflow/contrib/kernel_methods/python/mappers") -add_python_module("tensorflow/contrib/kfac") -add_python_module("tensorflow/contrib/kfac/examples") -add_python_module("tensorflow/contrib/kfac/python") -add_python_module("tensorflow/contrib/kfac/python/ops") -add_python_module("tensorflow/contrib/labeled_tensor") -add_python_module("tensorflow/contrib/labeled_tensor/python") -add_python_module("tensorflow/contrib/labeled_tensor/python/ops") -add_python_module("tensorflow/contrib/layers") -add_python_module("tensorflow/contrib/layers/kernels") -add_python_module("tensorflow/contrib/layers/ops") -add_python_module("tensorflow/contrib/layers/python") -add_python_module("tensorflow/contrib/layers/python/kernel_tests") -add_python_module("tensorflow/contrib/layers/python/layers") -add_python_module("tensorflow/contrib/layers/python/ops") -add_python_module("tensorflow/contrib/learn") -add_python_module("tensorflow/contrib/learn/python") -add_python_module("tensorflow/contrib/learn/python/learn") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/queues") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/transforms") -add_python_module("tensorflow/contrib/learn/python/learn/datasets") -add_python_module("tensorflow/contrib/learn/python/learn/datasets/data") -add_python_module("tensorflow/contrib/learn/python/learn/estimators") -add_python_module("tensorflow/contrib/learn/python/learn/learn_io") -add_python_module("tensorflow/contrib/learn/python/learn/ops") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/utils") -add_python_module("tensorflow/contrib/legacy_seq2seq") -add_python_module("tensorflow/contrib/legacy_seq2seq/python") -add_python_module("tensorflow/contrib/legacy_seq2seq/python/ops") -add_python_module("tensorflow/contrib/linalg") -add_python_module("tensorflow/contrib/linalg/python") -add_python_module("tensorflow/contrib/linalg/python/ops") -add_python_module("tensorflow/contrib/linalg/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer") -add_python_module("tensorflow/contrib/linear_optimizer/kernels") -add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc") -add_python_module("tensorflow/contrib/linear_optimizer/python") -add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer/python/ops") -add_python_module("tensorflow/contrib/lookup") -add_python_module("tensorflow/contrib/losses") -add_python_module("tensorflow/contrib/losses/python") -add_python_module("tensorflow/contrib/losses/python/losses") -add_python_module("tensorflow/contrib/losses/python/metric_learning") -add_python_module("tensorflow/contrib/makefile") -add_python_module("tensorflow/contrib/makefile/test") -add_python_module("tensorflow/contrib/memory_stats") -add_python_module("tensorflow/contrib/memory_stats/kernels") -add_python_module("tensorflow/contrib/memory_stats/ops") -add_python_module("tensorflow/contrib/memory_stats/python") -add_python_module("tensorflow/contrib/memory_stats/python/kernel_tests") -add_python_module("tensorflow/contrib/memory_stats/python/ops") -add_python_module("tensorflow/contrib/meta_graph_transform") -add_python_module("tensorflow/contrib/metrics") -add_python_module("tensorflow/contrib/metrics/kernels") -add_python_module("tensorflow/contrib/metrics/ops") -add_python_module("tensorflow/contrib/metrics/python") -add_python_module("tensorflow/contrib/metrics/python/kernel_tests") -add_python_module("tensorflow/contrib/metrics/python/metrics") -add_python_module("tensorflow/contrib/metrics/python/ops") -add_python_module("tensorflow/contrib/ndlstm") -add_python_module("tensorflow/contrib/ndlstm/python") -add_python_module("tensorflow/contrib/nn") -add_python_module("tensorflow/contrib/nn/python") -add_python_module("tensorflow/contrib/nn/python/ops") -add_python_module("tensorflow/contrib/nccl") -add_python_module("tensorflow/contrib/nccl/kernels") -add_python_module("tensorflow/contrib/nccl/ops") -add_python_module("tensorflow/contrib/nccl/python") -add_python_module("tensorflow/contrib/nccl/python/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/kernels") -add_python_module("tensorflow/contrib/nearest_neighbor/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/python") -add_python_module("tensorflow/contrib/nearest_neighbor/python/kernel_tests") -add_python_module("tensorflow/contrib/nearest_neighbor/python/ops") -add_python_module("tensorflow/contrib/opt") -add_python_module("tensorflow/contrib/opt/python") -add_python_module("tensorflow/contrib/opt/python/training") -add_python_module("tensorflow/contrib/pi_examples") -add_python_module("tensorflow/contrib/pi_examples/camera") -add_python_module("tensorflow/contrib/pi_examples/label_image") -add_python_module("tensorflow/contrib/pi_examples/label_image/data") -add_python_module("tensorflow/contrib/predictor") -add_python_module("tensorflow/contrib/quantization") -add_python_module("tensorflow/contrib/quantization/python") -add_python_module("tensorflow/contrib/quantize") -add_python_module("tensorflow/contrib/quantize/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") -add_python_module("tensorflow/contrib/resampler") -add_python_module("tensorflow/contrib/resampler/kernels") -add_python_module("tensorflow/contrib/resampler/ops") -add_python_module("tensorflow/contrib/resampler/python") -add_python_module("tensorflow/contrib/resampler/python/ops") -add_python_module("tensorflow/contrib/rnn") -add_python_module("tensorflow/contrib/rnn/kernels") -add_python_module("tensorflow/contrib/rnn/ops") -add_python_module("tensorflow/contrib/rnn/python") -add_python_module("tensorflow/contrib/rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/rnn/python/ops") -add_python_module("tensorflow/contrib/saved_model") -add_python_module("tensorflow/contrib/saved_model/python") -add_python_module("tensorflow/contrib/saved_model/python/saved_model") -add_python_module("tensorflow/contrib/seq2seq") -add_python_module("tensorflow/contrib/seq2seq/kernels") -add_python_module("tensorflow/contrib/seq2seq/ops") -add_python_module("tensorflow/contrib/seq2seq/python") -add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests") -add_python_module("tensorflow/contrib/seq2seq/python/ops") -add_python_module("tensorflow/contrib/session_bundle") -add_python_module("tensorflow/contrib/session_bundle/example") -add_python_module("tensorflow/contrib/session_bundle/testdata") -add_python_module("tensorflow/contrib/signal") -add_python_module("tensorflow/contrib/signal/python") -add_python_module("tensorflow/contrib/signal/python/ops") -add_python_module("tensorflow/contrib/slim") -add_python_module("tensorflow/contrib/slim/python") -add_python_module("tensorflow/contrib/slim/python/slim") -add_python_module("tensorflow/contrib/slim/python/slim/data") -add_python_module("tensorflow/contrib/slim/python/slim/nets") -add_python_module("tensorflow/contrib/solvers") -add_python_module("tensorflow/contrib/solvers/python") -add_python_module("tensorflow/contrib/solvers/python/ops") -add_python_module("tensorflow/contrib/sparsemax") -add_python_module("tensorflow/contrib/sparsemax/python") -add_python_module("tensorflow/contrib/sparsemax/python/ops") -add_python_module("tensorflow/contrib/specs") -add_python_module("tensorflow/contrib/specs/python") -add_python_module("tensorflow/contrib/staging") -add_python_module("tensorflow/contrib/stat_summarizer") -add_python_module("tensorflow/contrib/stateless") -add_python_module("tensorflow/contrib/tensorboard") -add_python_module("tensorflow/contrib/tensorboard/plugins") -add_python_module("tensorflow/contrib/tensorboard/plugins/projector") -add_python_module("tensorflow/contrib/tensor_forest") -add_python_module("tensorflow/contrib/tensor_forest/client") -add_python_module("tensorflow/contrib/tensor_forest/core") -add_python_module("tensorflow/contrib/tensor_forest/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/data") -add_python_module("tensorflow/contrib/tensor_forest/hybrid") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/layers") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/models") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/ops") -add_python_module("tensorflow/contrib/tensor_forest/python") -add_python_module("tensorflow/contrib/tensor_forest/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/python/ops") -add_python_module("tensorflow/contrib/testing") -add_python_module("tensorflow/contrib/testing/python") -add_python_module("tensorflow/contrib/testing/python/framework") -add_python_module("tensorflow/contrib/text") -add_python_module("tensorflow/contrib/text/kernels") -add_python_module("tensorflow/contrib/text/ops") -add_python_module("tensorflow/contrib/text/python") -add_python_module("tensorflow/contrib/text/python/ops") -add_python_module("tensorflow/contrib/tfprof") -add_python_module("tensorflow/contrib/timeseries") -add_python_module("tensorflow/contrib/timeseries/examples") -add_python_module("tensorflow/contrib/timeseries/examples/data") -add_python_module("tensorflow/contrib/timeseries/python") -add_python_module("tensorflow/contrib/timeseries/python/timeseries") -add_python_module("tensorflow/contrib/timeseries/python/timeseries/state_space_models") -add_python_module("tensorflow/contrib/tpu") -add_python_module("tensorflow/contrib/tpu/ops") -add_python_module("tensorflow/contrib/tpu/profiler") -add_python_module("tensorflow/contrib/tpu/python") -add_python_module("tensorflow/contrib/tpu/python/ops") -add_python_module("tensorflow/contrib/tpu/python/profiler") -add_python_module("tensorflow/contrib/tpu/python/tpu") -add_python_module("tensorflow/contrib/training") -add_python_module("tensorflow/contrib/training/python") -add_python_module("tensorflow/contrib/training/python/training") -add_python_module("tensorflow/contrib/util") -add_python_module("tensorflow/contrib/reduce_slice_ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/kernels") -add_python_module("tensorflow/contrib/reduce_slice_ops/ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/python") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") -add_python_module("tensorflow/contrib/summary") +FILE(READ python_modules.txt python_modules) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}") +STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}") + +foreach(python_module ${python_modules}) + add_python_module(${python_module}) +endforeach(python_module) + +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python") +add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E touch + "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/__init__.py") +add_custom_command( + TARGET tf_python_copy_scripts_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E touch + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") @@ -694,6 +266,9 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) set(require_shape_fn 1) endif() + get_filename_component(GENERATE_PYTHON_OP_LIB_MKDIRPATH ${GENERATE_PYTHON_OP_LIB_DESTINATION} PATH) + file(MAKE_DIRECTORY ${GENERATE_PYTHON_OP_LIB_MKDIRPATH}) + # Create a C++ executable that links in the appropriate op # registrations and generates Python wrapper code based on the # registered ops. @@ -714,7 +289,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) # containing the wrappers. add_custom_command( OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION} - COMMAND ${tf_python_op_lib_name}_gen_python @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} + COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} DEPENDS ${tf_python_op_lib_name}_gen_python ) @@ -722,6 +297,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ${GENERATE_PYTHON_OP_LIB_DESTINATION} PARENT_SCOPE) endfunction() +GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("math_ops") @@ -791,6 +367,9 @@ GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) + GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -863,6 +442,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor.h" @@ -873,6 +454,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc" "${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h" @@ -966,7 +549,7 @@ add_library(pywrap_tensorflow_internal SHARED $ $<$:$> $ - $<$:$> + $<$:$<$:$>> $<$:$> ${pywrap_tensorflow_deffile} ) @@ -989,6 +572,20 @@ target_link_libraries(pywrap_tensorflow_internal PRIVATE ) if(WIN32) + + # include contrib/periodic_resample as .so + # + set(tf_periodic_resample_srcs + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc" + ) + + AddUserOps(TARGET _periodic_resample_op + SOURCES "${tf_periodic_resample_srcs}" + DEPENDS pywrap_tensorflow_internal tf_python_ops + DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/) + # include contrib/nearest_neighbor as .so # set(tf_nearest_neighbor_srcs @@ -1042,25 +639,23 @@ if(WIN32) DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/python/ops/) endif(WIN32) -if(WIN32) - # include contrib/seq2seq as .so - # - set(tf_beam_search_srcs - "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc" - "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h" - "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc" - ) +# include contrib/seq2seq as .so +# +set(tf_beam_search_srcs + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h" + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc" +) - set(tf_beam_search_gpu_srcs - "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc" - ) +set(tf_beam_search_gpu_srcs + "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc" +) - AddUserOps(TARGET _beam_search_ops - SOURCES "${tf_beam_search_srcs}" - GPUSOURCES ${tf_beam_search_gpu_srcs} - DEPENDS pywrap_tensorflow_internal tf_python_ops - DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/python/ops/) -endif(WIN32) +AddUserOps(TARGET _beam_search_ops + SOURCES "${tf_beam_search_srcs}" + GPUSOURCES ${tf_beam_search_gpu_srcs} + DEPENDS pywrap_tensorflow_internal tf_python_ops + DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/seq2seq/python/ops/) ############################################################ # Build a PIP package containing the TensorFlow runtime. diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 9bf45bab3041142206900bf96beeddefb3308ee4..571d2b0decb5e9afcec2314f9837546f0974e90d 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -45,7 +45,7 @@ if(WIN32) $ $ ) - + set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def") set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE) @@ -73,7 +73,7 @@ add_library(tensorflow SHARED $ $<$:$> $ - $<$:$> + $<$:$<$:$>> $<$:$> ${tensorflow_deffile} ) @@ -94,3 +94,54 @@ endif() if(WIN32) add_dependencies(tensorflow tensorflow_static) endif(WIN32) + +target_include_directories(tensorflow PUBLIC + $ + $) + +install(TARGETS tensorflow EXPORT tensorflow_export + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) + +install(EXPORT tensorflow_export + FILE TensorflowConfig.cmake + DESTINATION lib/cmake) + +# install necessary headers +# tensorflow headers +install(DIRECTORY ${tensorflow_source_dir}/tensorflow/cc/ + DESTINATION include/tensorflow/cc + FILES_MATCHING PATTERN "*.h") +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tensorflow/cc/ + DESTINATION include/tensorflow/cc + FILES_MATCHING PATTERN "*.h") +install(DIRECTORY ${tensorflow_source_dir}/tensorflow/core/ + DESTINATION include/tensorflow/core + FILES_MATCHING PATTERN "*.h") +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tensorflow/core/ + DESTINATION include/tensorflow/core + FILES_MATCHING PATTERN "*.h") +install(DIRECTORY ${tensorflow_source_dir}/tensorflow/stream_executor/ + DESTINATION include/tensorflow/stream_executor + FILES_MATCHING PATTERN "*.h") +# google protobuf headers +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src/google/ + DESTINATION include/google + FILES_MATCHING PATTERN "*.h") +# nsync headers +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/ + DESTINATION include/external/nsync + FILES_MATCHING PATTERN "*.h") +# Eigen directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/ + DESTINATION include/Eigen) +# external directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/external/eigen_archive/ + DESTINATION include/external/eigen_archive) +# third_party eigen directory +install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ + DESTINATION include/third_party/eigen3) +# unsupported Eigen directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ + DESTINATION include/unsupported/Eigen) diff --git a/tensorflow/contrib/cmake/tf_stream_executor.cmake b/tensorflow/contrib/cmake/tf_stream_executor.cmake index 3d84f1ebb9c1fa1b2f3ccdd8d5ae8eaf182f7715..91ca33f4c4d5f6c822f45b0676e6e46d2e4c2860 100644 --- a/tensorflow/contrib/cmake/tf_stream_executor.cmake +++ b/tensorflow/contrib/cmake/tf_stream_executor.cmake @@ -61,19 +61,22 @@ file(GLOB tf_stream_executor_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/platform/default/*.h" ) -if (tensorflow_ENABLE_GPU) +if (tensorflow_ENABLE_GPU) file(GLOB tf_stream_executor_gpu_srcs "${tensorflow_source_dir}/tensorflow/stream_executor/cuda/*.cc" ) list(APPEND tf_stream_executor_srcs ${tf_stream_executor_gpu_srcs}) -endif() +endif() #file(GLOB_RECURSE tf_stream_executor_test_srcs # "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.cc" # "${tensorflow_source_dir}/tensorflow/stream_executor/*_test.h" #) -#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) +#list(REMOVE_ITEM tf_stream_executor_srcs ${tf_stream_executor_test_srcs}) +if (NOT WIN32) + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lgomp") +endif (NOT WIN32) add_library(tf_stream_executor OBJECT ${tf_stream_executor_srcs}) add_dependencies(tf_stream_executor diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 77d21249148cc900a1bb4fc2742956aee47734de..94ca4b00175dffb4461fca34c5ecd79ba79be778 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -139,12 +139,15 @@ if (tensorflow_BUILD_PYTHON_TESTS) file(GLOB_RECURSE tf_test_src_py ${tf_test_rnn_src_py} + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/debug/cli/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/internal/*_test.py" @@ -153,7 +156,8 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/stateless/python/kernel_tests/*_test.py" @@ -171,7 +175,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/distributions/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/learn/*_test.py" ) @@ -179,17 +182,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) # exclude the ones we don't want set(tf_test_src_py_exclude - # generally excluded + # Not a test. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" - - # Python source line inspection tests are flaky on Windows (b/36375074). - "${tensorflow_source_dir}/tensorflow/python/debug/cli/analyzer_cli_test.py" - "${tensorflow_source_dir}/tensorflow/python/debug/cli/profile_analyzer_cli_test.py" - # Windows does not have the curses library and uses readline. - "${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py" - # TFDBG grpc:// mode is not yet available on Windows. - "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" - "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" + # Flaky because of port collisions. + "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # generally not working "${tensorflow_source_dir}/tensorflow/python/profiler/pprof_profiler_test.py" # flaky test @@ -216,7 +212,15 @@ if (tensorflow_BUILD_PYTHON_TESTS) # TODO: failing tests. # Nothing critical in here but should get this list down to [] # The failing list is grouped by failure source - + # Python source line inspection tests are flaky on Windows (b/36375074). + "${tensorflow_source_dir}/tensorflow/python/debug/cli/analyzer_cli_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/cli/profile_analyzer_cli_test.py" + # Windows does not have the curses library and uses readline. + "${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py" + # TFDBG grpc:// mode is not yet available on Windows. + "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py" # stl on windows handles overflows different "${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py" @@ -225,6 +229,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Numerical issues, calculations off. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py" # Float division by zero "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" # Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces. @@ -233,11 +241,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py" # IteratorGetMax OutOfRangeError @@ -261,9 +269,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_grad_test.py" # cudaSolver handle creation fails. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Dataset tests - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" # Broken tensorboard test due to cmake issues. "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 @@ -294,6 +302,9 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Test should only be run manually "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py" + # Depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/control_flow_util_test.py" ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/cmake/tf_tools.cmake b/tensorflow/contrib/cmake/tf_tools.cmake index 6ef95989630a39eaedaddda68f7da709e7d9ab03..cb58a2e7df85b2f214654eff5547c5788592f208 100644 --- a/tensorflow/contrib/cmake/tf_tools.cmake +++ b/tensorflow/contrib/cmake/tf_tools.cmake @@ -73,7 +73,7 @@ add_executable(${transform_graph} $ $ $ - $<$:$> + $<$:$<$:$>> $<$:$> ) @@ -95,7 +95,7 @@ add_executable(${summarize_graph} $ $ $ - $<$:$> + $<$:$<$:$>> $<$:$> ) @@ -117,7 +117,7 @@ add_executable(${compare_graphs} $ $ $ - $<$:$> + $<$:$<$:$>> $<$:$> ) @@ -138,7 +138,7 @@ add_executable(${benchmark_model} $ $ $ - $<$:$> + $<$:$<$:$>> $<$:$> ) @@ -147,3 +147,8 @@ target_link_libraries(${benchmark_model} PUBLIC ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} ) + +install(TARGETS ${transform_graph} ${summarize_graph} ${compare_graphs} ${benchmark_model} + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) diff --git a/tensorflow/contrib/cmake/tf_tutorials.cmake b/tensorflow/contrib/cmake/tf_tutorials.cmake index 858e7dda92e9e9f456d5fc56b563b2e3ec998520..e63fccc1810b348e543159681a73e7a9c1422c01 100644 --- a/tensorflow/contrib/cmake/tf_tutorials.cmake +++ b/tensorflow/contrib/cmake/tf_tutorials.cmake @@ -34,3 +34,8 @@ target_link_libraries(tf_tutorials_example_trainer PUBLIC ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} ) + +install(TARGETS tf_tutorials_example_trainer + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index 8c2528f548799f9facef740b0134ac56966b2b04..bae66ffd4289308f2cbfc730ec50d057b13923fb 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -19,7 +19,7 @@ from one graph to another. The copied elements are initialized inside a user-specified scope in the other graph. There are separate functions to copy ops and variables. There is also a function to retrive the copied version of an op from the -first graph inside a scope in the second graph. +first graph inside a scope in the second graph. @@copy_op_to_graph @@copy_variable_to_graph @@ -225,7 +225,7 @@ def copy_op_to_graph(org_instance, to_graph, variables, new_original_op, op_def) #Use Graph's hidden methods to add the op - to_graph._add_op(new_op) + to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index 9174c5eb989908d5a318e228bf231686b5117798..b47fb426a193e0fcc075deafae3eaab698f18ec9 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -23,7 +23,6 @@ import itertools import numpy as np from tensorflow.contrib.crf.python.ops import crf -from tensorflow.python.framework import dtypes from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -33,43 +32,58 @@ from tensorflow.python.platform import test class CrfTest(test.TestCase): def testCrfSequenceScore(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - sequence_score = crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - sequence_score = array_ops.squeeze(sequence_score, [0]) - tf_sequence_score = sess.run(sequence_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - expected_binary_score = sum( - transition_params[tag_indices[i], tag_indices[i + 1]] - for i in range(sequence_lengths - 1)) - expected_sequence_score = expected_unary_score + expected_binary_score - self.assertAllClose(tf_sequence_score, expected_sequence_score) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[4, 5, -3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([1], dtype=np.int32) + ] + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + with self.test_session() as sess: + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + tf_sequence_score = sess.run(sequence_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + expected_binary_score = sum( + transition_params[tag_indices[i], tag_indices[i + 1]] + for i in range(sequence_lengths - 1)) + expected_sequence_score = expected_unary_score + expected_binary_score + self.assertAllClose(tf_sequence_score, expected_sequence_score) def testCrfUnaryScore(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) - tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - unary_score = crf.crf_unary_score( - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - inputs=array_ops.expand_dims(inputs, 0)) - unary_score = array_ops.squeeze(unary_score, [0]) - tf_unary_score = sess.run(unary_score) - expected_unary_score = sum(inputs[i][tag_indices[i]] - for i in range(sequence_lengths)) - self.assertAllClose(tf_unary_score, expected_unary_score) + for dtype in (np.int32, np.int64): + tag_indices = np.array([1, 2, 1, 0], dtype=dtype) + sequence_lengths = np.array(3, dtype=np.int32) + with self.test_session() as sess: + unary_score = crf.crf_unary_score( + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + inputs=array_ops.expand_dims(inputs, 0)) + unary_score = array_ops.squeeze(unary_score, [0]) + tf_unary_score = sess.run(unary_score) + expected_unary_score = sum(inputs[i][tag_indices[i]] + for i in range(sequence_lengths)) + self.assertAllClose(tf_unary_score, expected_unary_score) def testCrfBinaryScore(self): tag_indices = np.array([1, 2, 1, 0], dtype=np.int32) @@ -89,38 +103,54 @@ class CrfTest(test.TestCase): self.assertAllClose(tf_binary_score, expected_binary_score) def testCrfLogNorm(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - sequence_lengths = np.array(3, dtype=np.int32) - with self.test_session() as sess: - all_sequence_scores = [] - - # Compare the dynamic program with brute force computation. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequence_scores.append( - crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params))) - - brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores) - log_norm = crf.crf_log_norm( - inputs=array_ops.expand_dims(inputs, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - log_norm = array_ops.squeeze(log_norm, [0]) - tf_brute_force_log_norm, tf_log_norm = sess.run( - [brute_force_log_norm, log_norm]) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[3, -1, 3]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + with self.test_session() as sess: + all_sequence_scores = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequence_scores.append( + crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params))) + + brute_force_log_norm = math_ops.reduce_logsumexp(all_sequence_scores) + log_norm = crf.crf_log_norm( + inputs=array_ops.expand_dims(inputs, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + log_norm = array_ops.squeeze(log_norm, [0]) + tf_brute_force_log_norm, tf_log_norm = sess.run( + [brute_force_log_norm, log_norm]) - self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) + self.assertAllClose(tf_log_norm, tf_brute_force_log_norm) def testCrfLogLikelihood(self): inputs = np.array( @@ -201,50 +231,66 @@ class CrfTest(test.TestCase): expected_max_sequence[:sequence_lengths]) def testCrfDecode(self): - inputs = np.array( - [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) transition_params = np.array( [[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32) - sequence_lengths = np.array(3, dtype=np.int32) - num_words = inputs.shape[0] - num_tags = inputs.shape[1] - - with self.test_session() as sess: - all_sequence_scores = [] - all_sequences = [] - - # Compare the dynamic program with brute force computation. - for tag_indices in itertools.product( - range(num_tags), repeat=sequence_lengths): - tag_indices = list(tag_indices) - tag_indices.extend([0] * (num_words - sequence_lengths)) - all_sequences.append(tag_indices) - sequence_score = crf.crf_sequence_score( - inputs=array_ops.expand_dims(inputs, 0), - tag_indices=array_ops.expand_dims(tag_indices, 0), - sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), - transition_params=constant_op.constant(transition_params)) - sequence_score = array_ops.squeeze(sequence_score, [0]) - all_sequence_scores.append(sequence_score) - - tf_all_sequence_scores = sess.run(all_sequence_scores) - - expected_max_sequence_index = np.argmax(tf_all_sequence_scores) - expected_max_sequence = all_sequences[expected_max_sequence_index] - expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] - - actual_max_sequence, actual_max_score = crf.crf_decode( - array_ops.expand_dims(inputs, 0), - constant_op.constant(transition_params), - array_ops.expand_dims(sequence_lengths, 0)) - actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) - actual_max_score = array_ops.squeeze(actual_max_score, [0]) - tf_actual_max_sequence, tf_actual_max_score = sess.run( - [actual_max_sequence, actual_max_score]) - - self.assertAllClose(tf_actual_max_score, expected_max_score) - self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), - expected_max_sequence[:sequence_lengths]) + # Test both the length-1 and regular cases. + sequence_lengths_list = [ + np.array(3, dtype=np.int32), + np.array(1, dtype=np.int32) + ] + inputs_list = [ + np.array([[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], + dtype=np.float32), + np.array([[-1, 2, 1]], + dtype=np.float32), + ] + tag_indices_list = [ + np.array([1, 2, 1, 0], dtype=np.int32), + np.array([2], dtype=np.int32) + ] + + for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list, + inputs_list, + tag_indices_list): + num_words = inputs.shape[0] + num_tags = inputs.shape[1] + + with self.test_session() as sess: + all_sequence_scores = [] + all_sequences = [] + + # Compare the dynamic program with brute force computation. + for tag_indices in itertools.product( + range(num_tags), repeat=sequence_lengths): + tag_indices = list(tag_indices) + tag_indices.extend([0] * (num_words - sequence_lengths)) + all_sequences.append(tag_indices) + sequence_score = crf.crf_sequence_score( + inputs=array_ops.expand_dims(inputs, 0), + tag_indices=array_ops.expand_dims(tag_indices, 0), + sequence_lengths=array_ops.expand_dims(sequence_lengths, 0), + transition_params=constant_op.constant(transition_params)) + sequence_score = array_ops.squeeze(sequence_score, [0]) + all_sequence_scores.append(sequence_score) + + tf_all_sequence_scores = sess.run(all_sequence_scores) + + expected_max_sequence_index = np.argmax(tf_all_sequence_scores) + expected_max_sequence = all_sequences[expected_max_sequence_index] + expected_max_score = tf_all_sequence_scores[expected_max_sequence_index] + + actual_max_sequence, actual_max_score = crf.crf_decode( + array_ops.expand_dims(inputs, 0), + constant_op.constant(transition_params), + array_ops.expand_dims(sequence_lengths, 0)) + actual_max_sequence = array_ops.squeeze(actual_max_sequence, [0]) + actual_max_score = array_ops.squeeze(actual_max_score, [0]) + tf_actual_max_sequence, tf_actual_max_score = sess.run( + [actual_max_sequence, actual_max_score]) + + self.assertAllClose(tf_actual_max_score, expected_max_score) + self.assertEqual(list(tf_actual_max_sequence[:sequence_lengths]), + expected_max_sequence[:sequence_lengths]) if __name__ == "__main__": diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 7166e38b28365a6dbce9cf134f81b08a57c722de..7f5ae937b26f465076c6976429697c35924432e5 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -53,7 +53,9 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import dtypes +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import rnn @@ -101,12 +103,29 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, Returns: sequence_scores: A [batch_size] vector of unnormalized sequence scores. """ - # Compute the scores of the given tag sequence. - unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) - binary_scores = crf_binary_score(tag_indices, sequence_lengths, - transition_params) - sequence_scores = unary_scores + binary_scores - return sequence_scores + # If max_seq_len is 1, we skip the score calculation and simply gather the + # unary potentials of the single tag. + def _single_seq_fn(): + batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0] + example_inds = array_ops.reshape( + math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1]) + return array_ops.gather_nd( + array_ops.squeeze(inputs, [1]), + array_ops.concat([example_inds, tag_indices], axis=1)) + + def _multi_seq_fn(): + # Compute the scores of the given tag sequence. + unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs) + binary_scores = crf_binary_score(tag_indices, sequence_lengths, + transition_params) + sequence_scores = unary_scores + binary_scores + return sequence_scores + + return utils.smart_cond( + pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], + 1), + fn1=_single_seq_fn, + fn2=_multi_seq_fn) def crf_log_norm(inputs, sequence_lengths, transition_params): @@ -124,19 +143,32 @@ def crf_log_norm(inputs, sequence_lengths, transition_params): # algorithm. first_input = array_ops.slice(inputs, [0, 0, 0], [-1, 1, -1]) first_input = array_ops.squeeze(first_input, [1]) - rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) - # Compute the alpha values in the forward algorithm in order to get the - # partition function. - forward_cell = CrfForwardRnnCell(transition_params) - _, alphas = rnn.dynamic_rnn( - cell=forward_cell, - inputs=rest_of_input, - sequence_length=sequence_lengths - 1, - initial_state=first_input, - dtype=dtypes.float32) - log_norm = math_ops.reduce_logsumexp(alphas, [1]) - return log_norm + # If max_seq_len is 1, we skip the algorithm and simply reduce_logsumexp over + # the "initial state" (the unary potentials). + def _single_seq_fn(): + return math_ops.reduce_logsumexp(first_input, [1]) + + def _multi_seq_fn(): + """Forward computation of alpha values.""" + rest_of_input = array_ops.slice(inputs, [0, 1, 0], [-1, -1, -1]) + + # Compute the alpha values in the forward algorithm in order to get the + # partition function. + forward_cell = CrfForwardRnnCell(transition_params) + _, alphas = rnn.dynamic_rnn( + cell=forward_cell, + inputs=rest_of_input, + sequence_length=sequence_lengths - 1, + initial_state=first_input, + dtype=dtypes.float32) + log_norm = math_ops.reduce_logsumexp(alphas, [1]) + return log_norm + + max_seq_len = array_ops.shape(inputs)[1] + return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_likelihood(inputs, @@ -193,6 +225,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): offsets = array_ops.expand_dims( math_ops.range(batch_size) * max_seq_len * num_tags, 1) offsets += array_ops.expand_dims(math_ops.range(max_seq_len) * num_tags, 0) + # Use int32 or int64 based on tag_indices' dtype. + if tag_indices.dtype == dtypes.int64: + offsets = math_ops.to_int64(offsets) flattened_tag_indices = array_ops.reshape(offsets + tag_indices, [-1]) unary_scores = array_ops.reshape( @@ -305,7 +340,7 @@ def viterbi_decode(score, transition_params): Returns: viterbi: A [seq_len] list of integers containing the highest scoring tag - indicies. + indices. viterbi_score: A float containing the score for the Viterbi sequence. """ trellis = np.zeros_like(score) @@ -360,8 +395,8 @@ class CrfDecodeForwardRnnCell(rnn_cell.RNNCell): scope: Unused variable scope of this cell. Returns: - backpointers: [batch_size, num_tags], containing backpointers. - new_state: [batch_size, num_tags], containing new score values. + backpointers: A [batch_size, num_tags] matrix of backpointers. + new_state: A [batch_size, num_tags] matrix of new score values. """ # For simplicity, in shape comments, denote: # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). @@ -385,7 +420,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): """Initialize the CrfDecodeBackwardRnnCell. Args: - num_tags + num_tags: An integer. The number of tags. """ self._num_tags = num_tags @@ -401,8 +436,9 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): """Build the CrfDecodeBackwardRnnCell. Args: - inputs: [batch_size, num_tags], backpointer of next step (in time order). - state: [batch_size, 1], next position's tag index. + inputs: A [batch_size, num_tags] matrix of + backpointer of next step (in time order). + state: A [batch_size, 1] matrix of tag index of next step. scope: Unused variable scope of this cell. Returns: @@ -426,52 +462,71 @@ def crf_decode(potentials, transition_params, sequence_length): This is a function for tensor. Args: - potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of + potentials: A [batch_size, max_seq_len, num_tags] tensor of unary potentials. - transition_params: A [num_tags, num_tags] tensor, matrix of + transition_params: A [num_tags, num_tags] matrix of binary potentials. - sequence_length: A [batch_size] tensor, containing sequence lengths. + sequence_length: A [batch_size] vector of true sequence lengths. Returns: - decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32. - Contains the highest scoring tag indicies. - best_score: A [batch_size] tensor, containing the score of decode_tags. + decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`. + Contains the highest scoring tag indices. + best_score: A [batch_size] vector, containing the score of `decode_tags`. """ - # For simplicity, in shape comments, denote: - # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). - num_tags = potentials.get_shape()[2].value - - # Computes forward decoding. Get last score and backpointers. - crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) - initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) - initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] - inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] - backpointers, last_score = rnn.dynamic_rnn( - crf_fwd_cell, - inputs=inputs, - sequence_length=sequence_length - 1, - initial_state=initial_state, - time_major=False, - dtype=dtypes.int32) # [B, T - 1, O], [B, O] - backpointers = gen_array_ops.reverse_sequence( - backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O] - - # Computes backward decoding. Extract tag indices from backpointers. - crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) - initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), - dtype=dtypes.int32) # [B] - initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] - decode_tags, _ = rnn.dynamic_rnn( - crf_bwd_cell, - inputs=backpointers, - sequence_length=sequence_length - 1, - initial_state=initial_state, - time_major=False, - dtype=dtypes.int32) # [B, T - 1, 1] - decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] - decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T] - decode_tags = gen_array_ops.reverse_sequence( - decode_tags, sequence_length, seq_dim=1) # [B, T] - - best_score = math_ops.reduce_max(last_score, axis=1) # [B] - return decode_tags, best_score + # If max_seq_len is 1, we skip the algorithm and simply return the argmax tag + # and the max activation. + def _single_seq_fn(): + squeezed_potentials = array_ops.squeeze(potentials, [1]) + decode_tags = array_ops.expand_dims( + math_ops.argmax(squeezed_potentials, axis=1), 1) + best_score = math_ops.reduce_max(squeezed_potentials, axis=1) + return math_ops.cast(decode_tags, dtype=dtypes.int32), best_score + + def _multi_seq_fn(): + """Decoding of highest scoring sequence.""" + + # For simplicity, in shape comments, denote: + # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). + num_tags = potentials.get_shape()[2].value + + # Computes forward decoding. Get last score and backpointers. + crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) + initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) + initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] + inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] + backpointers, last_score = rnn.dynamic_rnn( # [B, T - 1, O], [B, O] + crf_fwd_cell, + inputs=inputs, + sequence_length=sequence_length - 1, + initial_state=initial_state, + time_major=False, + dtype=dtypes.int32) + backpointers = gen_array_ops.reverse_sequence( # [B, T - 1, O] + backpointers, sequence_length - 1, seq_dim=1) + + # Computes backward decoding. Extract tag indices from backpointers. + crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) + initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), # [B] + dtype=dtypes.int32) + initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] + decode_tags, _ = rnn.dynamic_rnn( # [B, T - 1, 1] + crf_bwd_cell, + inputs=backpointers, + sequence_length=sequence_length - 1, + initial_state=initial_state, + time_major=False, + dtype=dtypes.int32) + decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] + decode_tags = array_ops.concat([initial_state, decode_tags], # [B, T] + axis=1) + decode_tags = gen_array_ops.reverse_sequence( # [B, T] + decode_tags, sequence_length, seq_dim=1) + + best_score = math_ops.reduce_max(last_score, axis=1) # [B] + return decode_tags, best_score + + return utils.smart_cond( + pred=math_ops.equal( + potentials.shape[1].value or array_ops.shape(potentials)[1], 1), + fn1=_single_seq_fn, + fn2=_multi_seq_fn) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index f192f78b98174d4e1af2e91f90b6a285fe51b628..fce2c03e69bc4b8b0ac46b8e081a33c43c9d41ab 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -54,48 +54,13 @@ tf_gen_op_wrapper_py( deps = [":cudnn_rnn_ops_op_lib"], ) -tf_custom_op_py_library( - name = "cudnn_rnn_ops_py", - srcs = [ - "__init__.py", - "python/ops/cudnn_rnn_ops.py", - ], - dso = [ - ":python/ops/_cudnn_rnn_ops.so", - ], - kernels = [ - ":cudnn_rnn_kernels", - ":cudnn_rnn_ops_op_lib", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], - deps = [ - ":cudnn_rnn_ops", - "//tensorflow/contrib/rnn:rnn_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:common_shapes", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:init_ops", - "//tensorflow/python:layers_base", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_seed", - "//tensorflow/python:rnn_cell", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - "//tensorflow/python:util", - "//tensorflow/python:variable_scope", - ], -) - tf_custom_op_py_library( name = "cudnn_rnn_py", srcs = [ "__init__.py", + "python/layers/__init__.py", "python/layers/cudnn_rnn.py", + "python/ops/cudnn_rnn_ops.py", ], dso = [ ":python/ops/_cudnn_rnn_ops.so", @@ -108,7 +73,6 @@ tf_custom_op_py_library( visibility = ["//visibility:public"], deps = [ ":cudnn_rnn_ops", - ":cudnn_rnn_ops_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", @@ -129,7 +93,7 @@ cuda_py_test( size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], additional_deps = [ - ":cudnn_rnn_ops_py", + ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python/ops/losses:losses", @@ -154,7 +118,7 @@ cuda_py_test( cuda_py_test( name = "cudnn_rnn_test", - size = "large", + size = "enormous", srcs = ["python/kernel_tests/cudnn_rnn_test.py"], additional_deps = [ ":cudnn_rnn_py", diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 87ba834770d8f707c5364ed7bb8db4aaaa21f286..5d8c6191f8db9f96532aa78e4790a4665d3b4877 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -30,15 +30,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.cudnn_rnn.python.layers import * +# pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented @@ -56,4 +50,4 @@ _allowed_symbols = [ "CudnnRNNTanhSaveable", ] -remove_undocumented(__name__) +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc index 55fce0a916c9b057234d11d475b56322ce1e29d2..5d5f593d016a3bb9f7b5ea8f5cd40c29268dc4f5 100644 --- a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc @@ -577,6 +577,7 @@ class CudnnRNNParamsSizeOp : public CudnnRNNKernelCommon { .TypeConstraint("S"), \ CudnnRNNParamsSizeOp); +TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU @@ -711,6 +712,7 @@ class CudnnRNNParamsToCanonical : public CudnnRNNKernelCommon { .HostMemory("input_size") \ .TypeConstraint("T"), \ CudnnRNNParamsToCanonical); +TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU @@ -757,7 +759,9 @@ class CudnnRNNCanonicalToParams : public CudnnRNNKernelCommon { .HostMemory("input_size") \ .TypeConstraint("T"), \ CudnnRNNCanonicalToParams); -TF_CALL_float(REGISTER_GPU) TF_CALL_double(REGISTER_GPU); +TF_CALL_half(REGISTER_GPU); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU // Run the forward operation of the RNN model. @@ -906,6 +910,7 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint("T"), \ CudnnRNNForwardOp); +TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU @@ -1125,6 +1130,7 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint("T"), \ CudnnRNNBackwardOp); +TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc index 2b297282b264a3777e0a981a1ecccabb0a3a2c4e..9e41e67857101534e8bfef8d5d0b8a45ed8f1f76 100644 --- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc @@ -75,7 +75,7 @@ REGISTER_OP("CudnnRNNParamsSize") .Input("num_layers: int32") .Input("num_units: int32") .Input("input_size: int32") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr("S: {int32, int64}") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) @@ -130,7 +130,7 @@ REGISTER_OP("CudnnRNN") .Output("output_h: T") .Output("output_c: T") .Output("reserve_space: T") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) .Attr(kRNNDirectionAttrs) @@ -190,7 +190,7 @@ REGISTER_OP("CudnnRNNBackprop") .Output("input_h_backprop: T") .Output("input_c_backprop: T") .Output("params_backprop: T") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) .Attr(kRNNDirectionAttrs) @@ -236,7 +236,7 @@ REGISTER_OP("CudnnRNNParamsToCanonical") .Input("params: T") .Output("weights: num_params * T") .Output("biases: num_params * T") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr("num_params: int") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) @@ -279,7 +279,7 @@ REGISTER_OP("CudnnRNNCanonicalToParams") .Input("weights: num_params * T") .Input("biases: num_params * T") .Output("params: T") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr("num_params: int") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 9156087f338f0f59f102560d7538b1871c84e23e..5a667485beebe4bee7f051b5920920c72134987f 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -35,15 +35,11 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import rnn as rnn_lib -from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables -from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import gradient_descent from tensorflow.python.training import saver as saver_lib CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION @@ -123,45 +119,6 @@ def _CreateParamsSavable(params, return params_saveable -def _BuildCudnnForward(rnn_mode, - num_layers, - num_units, - input_data, - is_training=False): - input_data_shape = input_data.get_shape().with_rank(3) - batch_size = input_data_shape[1].value - input_size = input_data_shape[2].value - model = _CreateModel(rnn_mode, num_layers, num_units, input_size) - - # Set zero init input states - input_h = constant_op.constant( - np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) - has_input_c = (rnn_mode == cudnn_rnn_ops.CUDNN_LSTM) - if has_input_c: - input_c = constant_op.constant( - np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) - - # Set rnn params - params_size_t = model.params_size() - params = variables.Variable( - random_ops.random_uniform([params_size_t]), validate_shape=False) - args = { - "input_data": input_data, - "input_h": input_h, - "params": params, - "is_training": is_training - } - if has_input_c: - args["input_c"] = input_c - # Build cell - output_tuple = model(**args) - - # Create savable objects for params - _CreateParamsSavable(params, model) - - return output_tuple, model - - def _MinLSTMParamSize(num_layers, num_units, input_size, @@ -181,25 +138,6 @@ def _MinLSTMParamSize(num_layers, raise ValueError("%s direction is not supported.") -def _CreateCudnnCompatibleCanonicalRNN(cudnn_model, - inputs, - scope=None): - model = cudnn_model.rnn_mode - if model not in (cudnn_rnn_ops.CUDNN_LSTM, cudnn_rnn_ops.CUDNN_GRU): - raise ValueError("%s is not supported!" % model) - - num_units = cudnn_model.num_units - num_layers = cudnn_model.num_layers - # To reuse cuDNN-trained models, must use cudnn compatible rnn cells. - if model == cudnn_rnn_ops.CUDNN_LSTM: - single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units) - else: - single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units) - cell = rnn_cell_impl.MultiRNNCell([single_cell() for _ in range(num_layers)]) - return rnn_lib.dynamic_rnn( - cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) - - class CudnnRNNTestSaveRestore(TensorFlowTestCase): def _CompareWeights(self, lhs, rhs): @@ -436,143 +374,6 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): self._testSaveRestoreOutput(rnn_mode, direction, dtype) -class CudnnRNNTestCompatibleRnnCells(TensorFlowTestCase): - - @unittest.skipUnless(test.is_built_with_cuda(), - "Test only applicable when running on GPUs") - def testCudnnCompatibleRnnCells(self): - configs = [ - { - "num_layers": 1, - "seq_length": 3, - "num_units": 4, - "input_size": 5, - "batch_size": 6, - }, - { - "num_layers": 2, - "seq_length": 8, - "num_units": 4, - "input_size": 8, - "batch_size": 16, - }, - { - "num_layers": 2, - "seq_length": 3, - "num_units": 4, - "input_size": 5, - "batch_size": 6, - }, - { - "num_layers": 1, - "seq_length": 2, - "num_units": 2, - "input_size": 4, - "batch_size": 1, - }, - ] - for rnn, cfg in itertools.product((cudnn_rnn_ops.CUDNN_LSTM,), configs): - self._testCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], - cfg["num_units"], cfg["input_size"], - cfg["batch_size"], rnn) - # TODO(jamesqin): Add CudnnCompatibleGRUBlockCell. - for rnn, cfg in itertools.product((cudnn_rnn_ops.CUDNN_GRU,), configs): - self._testCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], - cfg["num_units"], cfg["input_size"], - cfg["batch_size"], rnn) - - def _testCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units, - input_size, batch_size, rnn_mode): - has_state_c = rnn_mode == cudnn_rnn_ops.CUDNN_LSTM - np.random.seed(0) - # Train graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - input_data = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - output_tuple, cudnn_model = _BuildCudnnForward( - rnn_mode, num_layers, num_units, input_data, is_training=True) - target_output = array_ops.placeholder(dtype=dtypes.float32, shape=None) - total_sum = sum(map(math_ops.reduce_sum, output_tuple)) - - loss_op = losses.log_loss(labels=target_output, predictions=total_sum) - optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) - train_op = optimizer.minimize(loss_op) - - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - # Train Cudnn model - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - # Train 128 steps - num_steps = 128 - for _ in range(num_steps): - inputs = np.random.rand(seq_length, batch_size, - input_size).astype(np.float32) - targets = np.random.rand() - sess.run( - train_op, feed_dict={input_data: inputs, - target_output: targets}) - - save_path = os.path.join(self.get_temp_dir(), - ("cudnn-rnn-%s-test" % rnn_mode)) - save_v = saver.save(sess, save_path) - self.assertEqual(save_path, save_v) - - # cuDNN inference graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - cudnn_inputs = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - (cudnn_output_tuple, cudnn_model) = _BuildCudnnForward( - rnn_mode, num_layers, num_units, cudnn_inputs, is_training=False) - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - inference_input = np.random.rand(seq_length, batch_size, - input_size).astype(np.float32) - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - sess.run(variables.global_variables_initializer()) - saver.restore(sess, save_path) - - # Cudnn inference - cudnn_output = sess.run( - cudnn_output_tuple, feed_dict={cudnn_inputs: inference_input}) - - # Canonical RNN inference graph - with ops.Graph().as_default(): - random_seed.set_random_seed(299) - cell_inputs = array_ops.placeholder( - dtypes.float32, shape=[seq_length, batch_size, input_size]) - (output, states) = _CreateCudnnCompatibleCanonicalRNN( - cudnn_model, cell_inputs) - saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) - - with self.test_session( - use_gpu=True, graph=ops.get_default_graph()) as sess: - saver.restore(sess, save_path) - - # BlockCell inference - output_v, states_v = sess.run( - [output, states], feed_dict={cell_inputs: inference_input}) - - # output across timestamps are packed into one tensor. - self.assertAllClose(cudnn_output[0], output_v, atol=1e-6, rtol=1e-6) - - for i in range(num_layers): - if has_state_c: - # output_h - self.assertAllClose( - cudnn_output[1][i, :], states_v[i].h, atol=1e-6, rtol=1e-6) - # output_c - self.assertAllClose( - cudnn_output[2][i, :], states_v[i].c, atol=1e-6, rtol=1e-6) - else: - self.assertAllClose( - cudnn_output[1][i, :], states_v[i], atol=1e-6, rtol=1e-6) - - class CudnnRNNTestParamsSize(TensorFlowTestCase): def _testOneLSTMParamsSize(self, num_layers, num_units, input_size, diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index 1ce8954bb09d7444a552d0ba6b3d9bb72cd919fd..e65394cba07574ed49398981f1cbd8bcb402e24f 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -17,8 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import collections import itertools import os +import sys import unittest import numpy as np @@ -49,6 +52,7 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import gradient_descent from tensorflow.python.training import saver as saver_lib + CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU @@ -78,9 +82,10 @@ class CudnnTestModel(object): dropout=0., dtype=dtypes.float32, training=False, + seed=None, kernel_initializer=None, bias_initializer=None): - if dtype not in (dtypes.float32, dtypes.float64): + if dtype not in (dtypes.float16, dtypes.float32, dtypes.float64): raise ValueError("Invalid dtype: %s" % dtype) self._dtype = dtype @@ -110,6 +115,7 @@ class CudnnTestModel(object): direction=direction, dropout=dropout, dtype=dtype, + seed=seed, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer) self._rnn.build([None, None, input_size]) @@ -499,7 +505,7 @@ class CudnnRNNTestSaveRestore(TensorFlowTestCase): def _TestSaveRestoreHelper(self, rnn_mode): directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] - dtype_list = [dtypes.float32, dtypes.float64] + dtype_list = [dtypes.float16, dtypes.float32, dtypes.float64] for direction, dtype in itertools.product(directions, dtype_list): self._TestSaveRestoreVariable(rnn_mode, direction, dtype) self._TestSaveRestoreTwoVariables(rnn_mode, direction, dtype) @@ -722,19 +728,17 @@ class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): outputs_v, output_state_v = sess.run( [outputs, output_state], feed_dict={cell_inputs: inference_input}) - self.assertAllClose(cudnn_outputs_v, outputs_v, atol=1e-5, rtol=1e-5) + self.assertAllClose(cudnn_outputs_v, outputs_v, atol=2e-5, rtol=2e-5) (cudnn_output_h_v,) = cudnn_output_states_v - self.assertAllClose(cudnn_output_h_v, output_state_v, atol=1e-5, - rtol=1e-5) + self.assertAllClose(cudnn_output_h_v, output_state_v, atol=2e-5, + rtol=2e-5) class CudnnRNNTestParamsSize(TensorFlowTestCase): def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size, - direction): + dtype, direction): logging.info("Testing one lstm param size with config: %s", locals()) - dtype = dtypes.float32 - model = CudnnTestModel( rnn_mode, num_layers, @@ -767,13 +771,14 @@ class CudnnRNNTestParamsSize(TensorFlowTestCase): [3, 200, 400], ] directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + dtype_list = [dtypes.float16, dtypes.float32, dtypes.float64] rnns = [CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH] - for (rnn, config, direction) in itertools.product(rnns, test_configs, - directions): + for (rnn, config, dtype, direction) in itertools.product( + rnns, test_configs, dtype_list, directions): num_layers, num_units, input_size = config with ops.Graph().as_default(): self._TestOpaqueParamsSize(rnn, num_layers, num_units, input_size, - direction) + dtype, direction) class CudnnRNNTestTraining(TensorFlowTestCase): @@ -819,9 +824,63 @@ class CudnnRNNTestTraining(TensorFlowTestCase): numeric_grad[i] = (y_pos - y_neg) / (2 * delta) return numeric_grad.reshape(x_shape) + def _GetShape(self, sess, inputs): + if not isinstance(inputs, collections.Iterable): + return sess.run(array_ops.shape(inputs)) + else: + return sess.run([array_ops.shape(x) for x in inputs]) + + def _GradientCheckFp16(self, sess, y, xs, num_samples, + tolerance=1e-6, delta=1e-4): + """Gradient check for Fp16. + + Fp16 numerical gradients end up being zeros. Use a new way to check + gradients: + + Given multi-variant function: + y = f(x1, x2, ... xn) + delta_y = f(x1 + delta_x1, x2+delta_x2, ..., xn+delta_xn) - + f(x1, x2, ..., xn) + = f'(x1) * delta_x1 + f'(x2) * delta_x2 + .. + f'(xn) * delta_xn + where: + delta_xi are very small disturbance. + f'(xi) is the gradient of y w.r.t xi. + + The gradient check verifies the expected delta_y calculated by the above + equation is close to the actual delta_y. + Args: + sess: tf.Session object. + y: output tensor. + xs: a tensor or a list of input tensors. + num_samples: number of test samples to run. + tolerance: error tolerance. + delta: the order of magnititued of input disturbance to apply to calculate + the output change w.r.t inputs. + """ + sym_grads = self._ComputeSymGrads(sess, y, xs) + xs_shapes = self._GetShape(sess, xs) + + x_vals = [sess.run(x) for x in xs] + for _ in range(num_samples): + delta_xs = [delta * np.random.rand(*shape.tolist()) + for shape in xs_shapes] + + feed_dict = {} + for x, x_val, delta_x in zip(xs, x_vals, delta_xs): + feed_dict[x] = x_val + delta_x + actual_delta_y = (float(sess.run(y, feed_dict=feed_dict)) - + float(sess.run(y))) + + expected_delta_y = 0. + for sym_grad, delta_x in zip(sym_grads, delta_xs): + expected_delta_y += np.dot( + sym_grad.astype(np.float32).flatten(), + delta_x.astype(np.float32).flatten()) + self.assertAllClose(expected_delta_y, actual_delta_y, + atol=tolerance, rtol=tolerance) + def _GradientCheck(self, sess, y, xs, tolerance=1e-6, delta=1e-4): - sym_grads_t = gradients.gradients(y, xs) - sym_grads = sess.run(sym_grads_t) + sym_grads = self._ComputeSymGrads(sess, y, xs) num_grads = [self._ComputeNumericGrad(sess, y, x, delta) for x in xs] self.assertEqual(len(sym_grads), len(num_grads)) @@ -830,6 +889,10 @@ class CudnnRNNTestTraining(TensorFlowTestCase): self.assertFalse(np.any(np.isnan(num))) self.assertAllClose(sym, num, atol=tolerance, rtol=tolerance) + def _ComputeSymGrads(self, sess, y, xs): + sym_grads_t = gradients.gradients(y, xs) + return sess.run(sym_grads_t) + def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, batch_size, seq_length, dir_count, dropout, dtype, delta, tolerance): @@ -838,6 +901,8 @@ class CudnnRNNTestTraining(TensorFlowTestCase): logging.info("Training test with config: %s", locals()) old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) + + np.random.seed(1234) random_seed.set_random_seed(5678) has_input_c = (rnn_mode == CUDNN_LSTM) direction = (CUDNN_RNN_UNIDIRECTION @@ -879,12 +944,22 @@ class CudnnRNNTestTraining(TensorFlowTestCase): all_inputs = [inputs, params] for s in initial_state: all_inputs.append(s) - self._GradientCheck( - sess, total_sum, all_inputs, tolerance=tolerance, delta=delta) + if dtype == dtypes.float16: + self._GradientCheckFp16( + sess, total_sum, all_inputs, + num_samples=FLAGS.grad_check_num_samples, + tolerance=tolerance, delta=delta) + else: + for _ in range(FLAGS.grad_check_num_samples): + # Each time choose a different set of inputs. + sess.run(variables.global_variables_initializer()) + self._GradientCheck( + sess, total_sum, all_inputs, + tolerance=tolerance, delta=delta) os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state def _TestSimpleTrainingHelper(self, rnn_mode, test_configs): - dropouts = [0., 0.5, 1.] + dropouts = [0, 0.5, 1.] for config, dropout in itertools.product(test_configs, dropouts): dtype = config.get("dtype", dtypes.float32) delta = config.get("delta", 1e-4) @@ -895,11 +970,12 @@ class CudnnRNNTestTraining(TensorFlowTestCase): self._TestOneSimpleTraining(rnn_mode, shape["num_layers"], shape["num_units"], shape["input_size"], shape["batch_size"], shape["seq_length"], - dir_count, dropout, dtype, delta, tolerance) + dir_count, dropout, dtype, delta, + tolerance) @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTrainingLSTM64(self): + def testSimpleTrainingLSTMFp64(self): test_configs = [ { "dtype": dtypes.float64, @@ -917,7 +993,7 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTrainingLSTM32(self): + def testSimpleTrainingLSTMFp32(self): test_configs = [ { "dtype": dtypes.float32, @@ -936,7 +1012,38 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTrainingGRU64(self): + def testSimpleTrainingLSTMFp16(self): + test_configs = [ + { + "dtype": dtypes.float16, + "delta": 1e-3, + "tolerance": 9e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + { + "dtype": dtypes.float16, + "delta": 1e-2, + "tolerance": 9e-2, + "shape": { + "num_layers": 2, + "num_units": 6, + "input_size": 8, + "batch_size": 6, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRUFp64(self): test_configs = [ { "dtype": dtypes.float64, @@ -954,7 +1061,7 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTrainingGRU32(self): + def testSimpleTrainingGRUFp32(self): test_configs = [ { "dtype": dtypes.float32, @@ -973,7 +1080,26 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTrainingRNNTanh64(self): + def testSimpleTrainingGRUFp16(self): + test_configs = [ + { + "dtype": dtypes.float16, + "delta": 2e-3, + "tolerance": 6e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanhFp64(self): test_configs = [ { "dtype": dtypes.float64, @@ -991,7 +1117,7 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTrainingRNNTanh32(self): + def testSimpleTrainingRNNTanhFp32(self): test_configs = [ { "dtype": dtypes.float32, @@ -1010,7 +1136,26 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTrainingRNNRelu64(self): + def testSimpleTrainingRNNTanhFp16(self): + test_configs = [ + { + "dtype": dtypes.float16, + "delta": 1e-3, + "tolerance": 5e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNReluFp64(self): test_configs = [ { "dtype": dtypes.float64, @@ -1028,10 +1173,29 @@ class CudnnRNNTestTraining(TensorFlowTestCase): @unittest.skipUnless(test.is_built_with_cuda(), "Test only applicable when running on GPUs") - def testSimpleTrainingRNNRelu32(self): + def testSimpleTrainingRNNReluFp32(self): test_configs = [ { "dtype": dtypes.float32, + "delta": 1e-4, + "tolerance": 3e-1, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNReluFp16(self): + test_configs = [ + { + "dtype": dtypes.float16, "delta": 1e-3, "tolerance": 7e-2, "shape": { @@ -1047,4 +1211,13 @@ class CudnnRNNTestTraining(TensorFlowTestCase): if __name__ == "__main__": + argv0 = sys.argv[0] + parser = argparse.ArgumentParser() + parser.add_argument( + "--grad_check_num_samples", + type=int, + default=5, + help="Number of samples to run for gradient check.") + FLAGS, unparsed = parser.parse_known_args() + sys.argv = [argv0] + unparsed googletest.main() diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f09466b631f69d6234573dd5eafada650421c117 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""layers module with higher level CudnnRNN primitives.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import sys + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import * +# pylint: enable=unused-import,wildcard-import + +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 3d3f8a3be0554c709ce053106f754f27d8ed630a..37c61a71a3bdac4fadef58ba8c24b853fb3638ef 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging + CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM @@ -45,6 +46,9 @@ CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE +__all__ = ["CudnnLSTM", "CudnnGRU", "CudnnRNNTanh", "CudnnRNNRelu"] + + class _CudnnRNN(base_layer.Layer): # pylint:disable=line-too-long """Abstract class for RNN layers with Cudnn implementation. @@ -146,7 +150,6 @@ class _CudnnRNN(base_layer.Layer): # Custom SaveableObject class for the CudnnRNN class. _saveable_cls = None - # TODO(jamesqin): support float16 CuDNN RNN def __init__(self, num_layers, num_units, @@ -177,7 +180,7 @@ class _CudnnRNN(base_layer.Layer): inputs of each layer. When set to 0, dropout is disabled. seed: the op seed used for initializing dropout. See @{tf.set_random_seed} for behavior. - dtype: tf.float32 or tf.float64 + dtype: tf.float16, tf.float32 or tf.float64 kernel_initializer: starting value to initialize the weight. bias_initializer: starting value to initialize the bias (default is all zeros). @@ -192,8 +195,9 @@ class _CudnnRNN(base_layer.Layer): cudnn_rnn_ops.check_direction(direction) cudnn_rnn_ops.check_input_mode(input_mode) - if dtype not in [dtypes.float32, dtypes.float64]: - raise ValueError("Only support float32, float64, provided %s" % dtype) + if dtype not in [dtypes.float16, dtypes.float32, dtypes.float64]: + raise ValueError( + "Only support float16, float32, float64, provided %s" % dtype) # Layer self.dtype is type name, the original DType object is kept here. self._plain_dtype = dtype self._num_layers = num_layers @@ -454,6 +458,8 @@ class _CudnnRNN(base_layer.Layer): weights=cu_weights, biases=cu_biases, input_mode=self._input_mode, + seed=self._seed, + dropout=self._dropout, direction=self._direction) def _forward(self, inputs, h, c, opaque_params, training): diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 7d658c746ee1ecd21cefca9c9e52f611869f6176..dcd3d4732a27ae4bec579ac12ac568dc4a53baaa 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -28,6 +28,7 @@ from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs @@ -54,6 +55,11 @@ CUDNN_INPUT_LINEAR_MODE = "linear_input" CUDNN_INPUT_SKIP_MODE = "skip_input" CUDNN_INPUT_AUTO_MODE = "auto_select" +# pylint:disable=protected-access +_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME +_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME +# pylint:enable=protected-access + class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): """Cudnn Compatible LSTMCell. @@ -86,9 +92,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): Cudnn compatible GRU (from Cudnn library user guide): ```python r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate - i_t = sigma(x_t * W_i + h_t-1 * R_i + b_Wi + b_Ru) # update gate + u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate - h_t = (1 - i_t) .* h'_t + i_t .* h_t-1 + h_t = (1 - u_t) .* h'_t + u_t .* h_t-1 ``` Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}): @@ -99,9 +105,6 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): ```python r .* (h * R) != (r .* h) * R ``` - - TODO(jamesqin): update the impl after Cudnn 7.1 when Nvidia would adopt the - canonical version compatible with other tf GRU cells. """ def __init__(self, num_units, reuse=None, kernel_initializer=None): @@ -111,33 +114,65 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): reuse=reuse, kernel_initializer=kernel_initializer) + def build(self, inputs_shape): + if inputs_shape[1].value is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" + % inputs_shape) + + input_depth = inputs_shape[1].value + self._gate_kernel = self.add_variable( + "gates/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth + self._num_units, 2 * self._num_units], + initializer=self._kernel_initializer) + self._gate_bias = self.add_variable( + "gates/%s" % _BIAS_VARIABLE_NAME, + shape=[2 * self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.constant_initializer(1.0, dtype=self.dtype))) + + self._candidate_input_kernel = self.add_variable( + "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[input_depth, self._num_units], + initializer=self._kernel_initializer) + self._candidate_hidden_kernel = self.add_variable( + "candidate/hidden_projection/%s" % _WEIGHTS_VARIABLE_NAME, + shape=[self._num_units, self._num_units], + initializer=self._kernel_initializer) + + self._candidate_input_bias = self.add_variable( + "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.zeros_initializer(dtype=self.dtype))) + self._candidate_hidden_bias = self.add_variable( + "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME, + shape=[self._num_units], + initializer=( + self._bias_initializer + if self._bias_initializer is not None + else init_ops.zeros_initializer(dtype=self.dtype))) + def call(self, inputs, state): """Gated recurrent unit (GRU) with nunits cells.""" - with vs.variable_scope("gates"): # Reset gate and update gate. - # We start with bias of 1.0 to not reset and not update. - bias_ones = self._bias_initializer - if self._bias_initializer is None: - dtype = inputs.dtype - bias_ones = init_ops.constant_initializer(1.0, dtype=dtype) - # pylint: disable=protected-access - value = math_ops.sigmoid( - rnn_cell_impl._linear([inputs, state], 2 * self._num_units, True, - bias_ones, self._kernel_initializer)) - r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) - # pylint: enable=protected-access - with vs.variable_scope("candidate"): - # pylint: disable=protected-access - with vs.variable_scope("input_projection"): - hi = rnn_cell_impl._linear(inputs, self._num_units, True, - self._bias_initializer, - self._kernel_initializer) - with vs.variable_scope("hidden_projection"): - hh = r * (rnn_cell_impl._linear(state, self._num_units, True, - self._bias_initializer, - self._kernel_initializer)) - # pylint: enable=protected-access - c = self._activation(hi + hh) - new_h = u * state + (1 - u) * c + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, state], 1), self._gate_kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias) + + value = math_ops.sigmoid(gate_inputs) + r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) + + candidate = nn_ops.bias_add( + math_ops.matmul(inputs, self._candidate_input_kernel), + self._candidate_input_bias) + candidate += r * nn_ops.bias_add( + math_ops.matmul(state, self._candidate_hidden_kernel), + self._candidate_hidden_bias) + candidate = self._activation(candidate) + new_h = (1-u) * candidate + u * state return new_h, new_h diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index eaede0e00ecf1986873d50709d135d3f4b3ac9cd..3b1c33063f1214b68f79560f50d56bf5d31c9560 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -17,8 +17,8 @@ py_library( deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", - "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", @@ -27,12 +27,8 @@ py_library( tf_custom_op_library( name = "_prefetching_ops.so", - srcs = [ - "ops/prefetching_ops.cc", - ], - deps = [ - "//tensorflow/contrib/data/kernels:prefetching_kernels", - ], + srcs = ["ops/prefetching_ops.cc"], + deps = ["//tensorflow/contrib/data/kernels:prefetching_kernels"], ) tf_gen_op_libs( @@ -42,7 +38,9 @@ tf_gen_op_libs( filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md index 30e909111f460bb4d0ea5fcdefaf5bdedc93b9c0..848782e8d89b8670caf3b45de4912a7e0855c102 100644 --- a/tensorflow/contrib/data/README.md +++ b/tensorflow/contrib/data/README.md @@ -18,7 +18,7 @@ The arguments accepted by the `Dataset.map()` transformation have changed: * `dataset.map(..., num_threads=T)` is now `dataset.map(num_parallel_calls=T)`. * `dataset.map(..., output_buffer_size=B)` is now - `dataset.map(...).prefetch(B). + `dataset.map(...).prefetch(B)`. Some transformations have been removed from `tf.data.Dataset`, and you must instead apply them using `Dataset.apply()` transformation. The full list of diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 6c46acf20442c2cc435829afa57e8383b493d6af..c9ad091bd44d6e3a9368e182c3df9fc1c6e48071 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -17,12 +17,14 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@Dataset +@@Counter @@Iterator @@TFRecordDataset @@FixedLengthRecordDataset @@TextLineDataset @@batch_and_drop_remainder +@@padded_batch_and_drop_remainder @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ -30,7 +32,9 @@ See the @{$datasets$Importing Data} Programmer's Guide for an overview. @@make_saveable_from_iterator @@read_batch_features @@unbatch +@@parallel_interleave @@rejection_resample +@@scan @@sloppy_interleave @@get_single_element @@ -44,12 +48,15 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch +from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import unbatch +from tensorflow.contrib.data.python.ops.counter import Counter from tensorflow.contrib.data.python.ops.dataset_ops import Dataset from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset @@ -58,6 +65,8 @@ from tensorflow.contrib.data.python.ops.readers import SqlDataset from tensorflow.contrib.data.python.ops.readers import TextLineDataset from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample +from tensorflow.contrib.data.python.ops.scan_ops import scan +from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 424eb198522ce3d11152c2f8da6a2a5d82432cec..9b6ad9329482815b666d11d1b32b245e3ea62b54 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,14 +4,16 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", size = "small", srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", @@ -95,8 +97,8 @@ py_test( "nomac", # b/62040583 ], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -108,18 +110,42 @@ py_test( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training", + "//tensorflow/python:tensor_shape", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], ) +py_library( + name = "dataset_serialization_test", + testonly = 1, + srcs = [ + "dataset_serialization_test_base.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", + ], +) + py_test( name = "filter_dataset_op_test", size = "small", srcs = ["filter_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -131,21 +157,28 @@ py_test( ], ) -py_test( +tf_py_test( name = "flat_map_dataset_op_test", size = "small", srcs = ["flat_map_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ + ":dataset_serialization_test", + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", "//tensorflow/python:session", "//tensorflow/python:training", - "//third_party/py/numpy", + "//tensorflow/python:variable_scope", ], + grpc_enabled = True, + tags = ["no_pip"], ) py_test( @@ -157,6 +190,7 @@ py_test( "manual", # b/67958761 ], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", @@ -166,18 +200,18 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//third_party/py/numpy", ], ) -py_test( +tf_py_test( name = "iterator_ops_cluster_test", size = "small", srcs = ["iterator_ops_cluster_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], - deps = [ + additional_deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", @@ -191,14 +225,19 @@ py_test( "//tensorflow/python:session", "//tensorflow/python/data/ops:iterator_ops", ], + grpc_enabled = True, + tags = [ + "no_windows", + "oss_serial", + ], ) -py_test( +tf_py_test( name = "iterator_ops_test", size = "small", srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", @@ -220,8 +259,8 @@ py_test( "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", ], + grpc_enabled = True, ) py_test( @@ -241,12 +280,13 @@ py_test( py_test( name = "map_dataset_op_test", - size = "small", + size = "medium", srcs = ["map_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -255,20 +295,35 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:function", "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", - "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//third_party/py/numpy", ], ) +py_test( + name = "prefetch_dataset_op_test", + size = "small", + srcs = ["prefetch_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/python:platform", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_test( name = "range_dataset_op_test", size = "small", @@ -297,25 +352,22 @@ py_test( py_test( name = "reader_dataset_ops_test", - size = "small", + size = "medium", srcs = ["reader_dataset_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - "//tensorflow/contrib/data/python/ops:iterator_ops", + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:io_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", ], @@ -341,10 +393,12 @@ py_test( py_test( name = "sequence_dataset_op_test", - size = "small", + size = "medium", srcs = ["sequence_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -368,16 +422,24 @@ py_test( py_test( name = "shuffle_dataset_op_test", - size = "small", + size = "medium", srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], @@ -397,21 +459,32 @@ py_test( ], ) +py_test( + name = "stats_dataset_ops_test", + size = "small", + srcs = ["stats_dataset_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + ], +) + py_test( name = "zip_dataset_op_test", size = "small", srcs = ["zip_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:training", - "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], ) @@ -421,20 +494,31 @@ py_test( size = "small", srcs = ["prefetching_ops_test.py"], srcs_version = "PY2AND3", + tags = [ + "manual", + "no_oss", # b/68785503 + ], deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", ], ) filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index add17ff8bcea0f228dc36ec6157fe95b9ce44d80..d975a0167fe2cc8ae81431a8687aaf8695119a98 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -21,6 +21,7 @@ import math import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -51,8 +52,9 @@ class BatchDatasetTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count).batch(batch_size).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(count).batch(batch_size).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -68,7 +70,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(14): - self.assertAllEqual(component[(i*14 + j) % 7]**2, + self.assertAllEqual(component[(i * 14 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -83,12 +85,12 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(8): - self.assertAllEqual(component[(i*8 + j) % 7]**2, + self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) result = sess.run(get_next) for component, result_component in zip(components, result): for j in range((14 * 7) % 8): - self.assertAllEqual(component[((num_batches - 1)*8 + j) % 7]**2, + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -102,14 +104,67 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).batch( + 5).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testNestedBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch( + 2).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], + [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]], + values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + dense_shape=[2, 5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testPaddedBatchDataset(self): seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) - iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens) - .map(lambda x: array_ops.fill([x], x)).padded_batch( - 4, - padded_shapes=padded_shape).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(seq_lens) + .map(lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=padded_shape).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -117,35 +172,40 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [-1], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) padded_len = np.max(result) self.assertEqual((4, padded_len), result.shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test with random sequence lengths, and constant padding. - sess.run(init_op, feed_dict={padded_shape: [25], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [25], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) self.assertEqual((4, 25), result.shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len)) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) # Test correct handling of empty tensors. - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: [0, 0, 0, 0]}) + sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]}) result = sess.run(get_next) self.assertAllEqual([[], [], [], []], result) with self.assertRaises(errors.OutOfRangeError): @@ -153,8 +213,7 @@ class BatchDatasetTest(test.TestCase): # Test error handling with constant sequence lengths, and # too-short padding. - sess.run(init_op, feed_dict={padded_shape: [5], - seq_lens: [6, 5, 5, 5]}) + sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]}) with self.assertRaises(errors.DataLossError): result = sess.run(get_next) @@ -165,11 +224,13 @@ class BatchDatasetTest(test.TestCase): def fill_tuple(x): filled = array_ops.fill([x], x) return (filled, string_ops.as_string(filled)) - iterator = (dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) - .padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")).make_initializable_iterator()) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) + .padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -177,15 +238,18 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: # Test with random sequence lengths, and max padding. random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run(init_op, feed_dict={padded_shape: [-1], - seq_lens: random_seq_lens}) + sess.run( + init_op, feed_dict={ + padded_shape: [-1], + seq_lens: random_seq_lens + }) for i in range(8): result = sess.run(get_next) padded_len = np.max(result[0]) self.assertEqual((4, padded_len), result[0].shape) self.assertEqual((4, padded_len), result[1].shape) for j in range(4): - seq_len = random_seq_lens[(i*4)+j] + seq_len = random_seq_lens[(i * 4) + j] self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len) self.assertAllEqual(result[0][j, seq_len:], [-1] * (padded_len - seq_len)) @@ -219,20 +283,30 @@ class BatchDatasetTest(test.TestCase): constant_op.constant([-1, -1], dtype=dtypes.int64), constant_op.constant([37], dtype=dtypes.int64))) - for dataset in [dynamic_padding_from_tensor_shapes, - dynamic_padding_from_lists, - dynamic_padding_from_lists_with_minus_one, - dynamic_padding_from_tensors]: + for dataset in [ + dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists, + dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors + ]: self.assertEqual([None, None], dataset.output_shapes[0].as_list()) self.assertEqual([None, None, None], dataset.output_shapes[1].as_list()) self.assertEqual([None, 37], dataset.output_shapes[2].as_list()) + def testPaddedBatchSparseError(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) + def testDenseToSparseBatchDataset(self): components = np.random.randint(12, size=(100,)).astype(np.int32) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, [12])) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, + [12])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -241,24 +315,26 @@ class BatchDatasetTest(test.TestCase): for start in range(0, len(components), 4): results = sess.run(get_next) + self.assertAllEqual([[i, j] + for i, c in enumerate(components[start:start + 4]) + for j in range(c)], results.indices) self.assertAllEqual( - [[i, j] for i, c in enumerate(components[start:start+4]) - for j in range(c)], results.indices) - self.assertAllEqual( - [c for c in components[start:start+4] for _ in range(c)], + [c for c in components[start:start + 4] for _ in range(c)], results.values) - self.assertAllEqual( - [min(4, len(components) - start), 12], results.dense_shape) + self.assertAllEqual([min(4, + len(components) - start), 12], + results.dense_shape) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testDenseToSparseBatchDatasetWithUnknownShape(self): components = np.random.randint(5, size=(40,)).astype(np.int32) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.fill([x, x], x)).apply( - batching.dense_to_sparse_batch( - 4, [5, -1])).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([x, x], x)).apply( + batching.dense_to_sparse_batch( + 4, [5, -1])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -267,27 +343,30 @@ class BatchDatasetTest(test.TestCase): for start in range(0, len(components), 4): results = sess.run(get_next) - self.assertAllEqual( - [[i, j, z] for i, c in enumerate(components[start:start+4]) - for j in range(c) for z in range(c)], results.indices) - self.assertAllEqual( - [c for c in components[start:start+4] - for _ in range(c) for _ in range(c)], - results.values) - self.assertAllEqual( - [min(4, len(components) - start), - 5, - np.max(components[start:start+4])], - results.dense_shape) + self.assertAllEqual([[i, j, z] + for i, c in enumerate(components[start:start + 4]) + for j in range(c) + for z in range(c)], results.indices) + self.assertAllEqual([ + c + for c in components[start:start + 4] for _ in range(c) + for _ in range(c) + ], results.values) + self.assertAllEqual([ + min(4, + len(components) - start), 5, + np.max(components[start:start + 4]) + ], results.dense_shape) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) - iterator = (dataset_ops.Dataset.from_tensors(input_tensor) - .apply(batching.dense_to_sparse_batch(4, [-2])) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, [-2])) + .make_initializable_iterator()) init_op = iterator.initializer with self.test_session() as sess: @@ -297,8 +376,10 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetShapeErrors(self): input_tensor = array_ops.placeholder(dtypes.int32) - iterator = (dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, [12])).make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, + [12])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -355,8 +436,7 @@ class BatchDatasetTest(test.TestCase): def testUnbatchMultiElementTupleDataset(self): data = tuple([(math_ops.range(10 * i, 10 * i + 10), - array_ops.fill([10], "hi")) - for i in range(3)]) + array_ops.fill([10], "hi")) for i in range(3)]) data = dataset_ops.Dataset.from_tensor_slices(data) expected_types = ((dtypes.int32, dtypes.string),) * 3 data = data.batch(2) @@ -369,9 +449,7 @@ class BatchDatasetTest(test.TestCase): with self.test_session() as sess: for i in range(10): - self.assertEqual(((i, b"hi"), - (10 + i, b"hi"), - (20 + i, b"hi")), + self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")), sess.run(op)) with self.assertRaises(errors.OutOfRangeError): @@ -384,9 +462,10 @@ class BatchDatasetTest(test.TestCase): batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(batch_size)) + .make_initializable_iterator()) next_element = iterator.get_next() @@ -403,14 +482,85 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(next_element) + def testBatchAndDropRemainderSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(12).map(_sparse).apply( + batching.batch_and_drop_remainder(5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testPaddedBatchAndDropRemainder(self): + els = [] + for length in [3, 6, 9, 4, 12, 10, 2]: + els.append((np.array(length), np.arange(length) + 1, + np.array(length * 2))) + + dataset = dataset_ops.Dataset.from_tensors(els[0]) + for el in els[1:]: + dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el)) + + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = ( + dataset.apply( + batching.padded_batch_and_drop_remainder( + batch_size, ([], [None], []))).make_initializable_iterator()) + + next_element = iterator.get_next() + + with self.test_session() as sess: + for test_batch_size in [1, 3, 7, 10]: + sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size}) + num_batches = 7 // test_batch_size + for i in range(num_batches): + result = sess.run(next_element) + for component_idx, result_component in enumerate(result): + for j in range(test_batch_size): + data_idx = i * test_batch_size + j + comp = result_component[j] + unpadded = comp[comp > 0] + if np.isscalar(comp): + # The boolean mask indexing above adds a dim back. Rm it. + unpadded = unpadded[0] + self.assertAllEqual(els[data_idx][component_idx], unpadded) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPaddedBatchAndDropRemainderSparseError(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + with self.assertRaises(TypeError): + _ = dataset_ops.Dataset.range(10).map(_map_fn).apply( + batching.padded_batch_and_drop_remainder(5)) + def testBatchAndDropRemainderShapeInference(self): - components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder( - dtypes.int32, shape=[None]), array_ops.placeholder( - dtypes.int32, shape=[20, 30]))) + components = (array_ops.placeholder(dtypes.int32), + (array_ops.placeholder(dtypes.int32, shape=[None]), + array_ops.placeholder(dtypes.int32, shape=[20, 30]))) # Test with a statically known batch size. - dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(128))) + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(128))) self.assertIs(None, dataset.output_shapes[0].ndims) self.assertEqual([128], dataset.output_shapes[1][0].as_list()) @@ -419,14 +569,15 @@ class BatchDatasetTest(test.TestCase): # Test with a dynamic batch size: the static shape will be unknown, because # `batch_size` is a placeholder. batch_size = array_ops.placeholder(dtypes.int64) - dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply( - batching.batch_and_drop_remainder(batch_size))) + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(batch_size))) self.assertIs(None, dataset.output_shapes[0].ndims) self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def testBatchAndMapDataset(self): + def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). @@ -440,9 +591,13 @@ class BatchDatasetTest(test.TestCase): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components).repeat(count) - .apply(batching.map_and_batch(_map_fn, batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches)) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -458,7 +613,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(14): - self.assertAllEqual(component[(i*14 + j) % 7]**2, + self.assertAllEqual(component[(i * 14 + j) % 7]**2, result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -473,7 +628,7 @@ class BatchDatasetTest(test.TestCase): result = sess.run(get_next) for component, result_component in zip(components, result): for j in range(8): - self.assertAllEqual(component[(i*8 + j) % 7]**2, + self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) # The last batch should fail with `OutOfRange`. with self.assertRaises(errors.OutOfRangeError): @@ -488,14 +643,49 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def testBatchAndMapDataset(self): + return self._testBatchAndMapDatasetHelper() + + def testBatchAndMapDatasetWithParallelBatching(self): + # TODO(b/70299909): This test surfaces a bug in the `map_and_batch` + # transformation, which manifests as premature EOF. Fix it. + # + # return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + pass + + def testMapAndBatchSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + iterator = dataset_ops.Dataset.range(10).apply( + batching.map_and_batch(_sparse, 5)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(2): + actual = sess.run(get_next) + expected = sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], + values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], + dense_shape=[5, 1]) + self.assertTrue(sparse_tensor.is_sparse(actual)) + self.assertSparseValuesEqual(actual, expected) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testBatchAndMapDatasetFails(self): """Test a dataset that maps a TF function across its input elements.""" dataset = dataset_ops.Dataset.from_tensors( array_ops.check_numerics( constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = (dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) - .make_initializable_iterator()) + iterator = ( + dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) init_op = iterator.initializer with self.test_session() as sess: with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): @@ -503,6 +693,7 @@ class BatchDatasetTest(test.TestCase): def testBatchAndMapDatasetShapeMismatch(self): """Test a dataset that maps a TF function across its input elements.""" + def generator(): yield [1] yield [2] @@ -523,5 +714,63 @@ class BatchDatasetTest(test.TestCase): "number of elements does not match"): sess.run(get_next) + +class BatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len // batch_size + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + + +class PaddedBatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testPaddedBatch(self): + + def build_dataset(seq_lens): + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + lambda x: array_ops.fill([x], x)).padded_batch( + 4, padded_shapes=[-1]) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + def testPaddedBatchNonDefaultPadding(self): + + def build_dataset(seq_lens): + + def fill_tuple(x): + filled = array_ops.fill([x], x) + return (filled, string_ops.as_string(filled)) + + padded_shape = [-1] + return dataset_ops.Dataset.from_tensor_slices(seq_lens).map( + fill_tuple).padded_batch( + 4, + padded_shapes=(padded_shape, padded_shape), + padding_values=(-1, "")) + + seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32) + seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32) + self.run_core_tests(lambda: build_dataset(seq_lens1), + lambda: build_dataset(seq_lens2), 8) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index c3d6bfc097798530008f186cce68906b6af8fe47..55a1d3b95b212466b262ad3c26f1efd7ed0e067e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,14 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import threading import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.util import nest @@ -32,16 +31,16 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib class DatasetConstructorTest(test.TestCase): - def testTensorDataset(self): + def testFromTensors(self): """Test an dataset that represents a single tuple of tensors.""" components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) @@ -61,7 +60,75 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testTensorSliceDataset(self): + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testFromTensorsSparse(self): + """Test an dataset that represents a single tuple of tensors.""" + components = (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1]]), + values=np.array([-1, 1]), + dense_shape=np.array([2, 2]))) + + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual( + [tensor_shape.TensorShape(c.dense_shape) for c in components], + [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + results = sess.run(get_next) + for component, result_component in zip(components, results): + self.assertSparseValuesEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorsMixed(self): + """Test an dataset that represents a single tuple of tensors.""" + components = (np.array(1), np.array([1, 2, 3]), np.array(37.0), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1]]), + values=np.array([-1, 1]), + dense_shape=np.array([2, 2]))) + + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([ + tensor_shape.TensorShape(c.dense_shape) + if sparse_tensor.is_sparse(c) else c.shape for c in components + ], [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + results = sess.run(get_next) + for component, result_component in zip(components, results): + if sparse_tensor.is_sparse(component): + self.assertSparseValuesEqual(component, result_component) + else: + self.assertAllEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorSlices(self): """Test an dataset that represents the slices from a tuple of tensors.""" components = ( np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( @@ -86,7 +153,127 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testTensorSliceDatasetWithDict(self): + def testFromTensorSlicesSparse(self): + """Test an dataset that represents the slices from a tuple of tensors.""" + components = (sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1], [2, 2]]), + values=np.array([1, 2, 3]), + dense_shape=np.array([3, 3]))) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual( + [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components], + [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + expected = [ + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([1]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[1]]), + values=np.array([2]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[2]]), + values=np.array([3]), + dense_shape=np.array([3]))), + ] + for i in range(3): + results = sess.run(get_next) + for component, result_component in zip(expected[i], results): + self.assertSparseValuesEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorSlicesMixed(self): + """Test an dataset that represents the slices from a tuple of tensors.""" + components = (np.tile(np.array([[1], [2], [3]]), 20), + np.tile(np.array([[12], [13], [14]]), 22), + np.array([37.0, 38.0, 39.0]), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1], [2, 2]]), + values=np.array([1, 2, 3]), + dense_shape=np.array([3, 3]))) + + iterator = ( + dataset_ops.Dataset.from_tensor_slices(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([ + tensor_shape.TensorShape(c.dense_shape[1:]) + if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components + ], [shape for shape in iterator.output_shapes]) + + with self.test_session() as sess: + sess.run(init_op) + expected = [ + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([1]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[1]]), + values=np.array([2]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[2]]), + values=np.array([3]), + dense_shape=np.array([3]))), + ] + for i in range(3): + results = sess.run(get_next) + for component, result_component in zip( + (zip(*components[:3])[i] + expected[i]), results): + if sparse_tensor.is_sparse(component): + self.assertSparseValuesEqual(component, result_component) + else: + self.assertAllEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testFromTensorSlicesWithDict(self): components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} iterator = (dataset_ops.Dataset.from_tensor_slices(components) .make_initializable_iterator()) @@ -107,7 +294,7 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testSparseTensorSliceDataset(self): + def testFromSparseTensorSlices(self): """Test a dataset based on slices of a `tf.SparseTensor`.""" st = array_ops.sparse_placeholder(dtypes.float64) iterator = (dataset_ops.Dataset.from_sparse_tensor_slices(st) @@ -574,135 +761,63 @@ class DatasetConstructorTest(test.TestCase): new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) # pylint: enable=protected-access - def _iterator_checkpoint_prefix(self): - return os.path.join(self.get_temp_dir(), "iterator") - def _testSaveRestoreFromTensorsUtility(self, start, break_range, stop): - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step +class DatasetConstructorSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) + def _build_tensor_dataset(self, variable_array): + components = (variable_array, np.array([1, 2, 3]), np.array(37.0)) - with ops.Graph().as_default() as g: - iterator = ( - dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - saveable = iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - for t in nest.flatten(get_next): - ops.add_to_collection("get_next", t) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(start, break_range): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component, result_component) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b", "c"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for _ in range(break_range, stop): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + return dataset_ops.Dataset.from_tensors(components) - def testRestoreFromTensors(self): - self._testSaveRestoreFromTensorsUtility(0, 0, 1) + def testFromTensorsCore(self): + # Equal length components + arr = np.array(1) + num_outputs = 1 + diff_arr = np.array(2) + self.run_core_tests(lambda: self._build_tensor_dataset(arr), + lambda: self._build_tensor_dataset(diff_arr), + num_outputs) - def testRestoreExhuatedIteratorFromTensors(self): - self._testSaveRestoreFromTensorsUtility(0, 1, 1) + def _build_tensor_slices_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components) - def _build_graph_tensor_slices(self, components): - iterator = dataset_ops.Dataset.from_tensor_slices( - components).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - saveable = iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - for t in nest.flatten(get_next): - ops.add_to_collection("get_next", t) - return init_op, get_next - - def _testSaveRestoreFromTensorSlicesUtility(self, start, break_range, stop): - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step - - components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 22), + def testFromTensorSlicesCore(self): + # Equal length components + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), np.array([37.0, 38.0, 39.0, 40.0])) - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph_tensor_slices(components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(start, break_range): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i], result_component) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b", "c"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_range, stop): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreFromTensorSlices(self): - self._testSaveRestoreFromTensorSlicesUtility(0, 4, 2) - - def testRestoreExhaustedIteratorFromTensorSlices(self): - self._testSaveRestoreFromTensorSlicesUtility(0, 4, 4) - - def tesRestoreFromTensorSlicesWithDict(self): - - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step - - components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph_tensor_slices(components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(2): - results = sess.run(get_next) - self.assertEqual(components["foo"][i], results["foo"]) - self.assertEqual(components["bar"][i], results["bar"]) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(2, 3): - results = sess.run(get_next) - self.assertEqual(components["foo"][i], results["foo"]) - self.assertEqual(components["bar"][i], results["bar"]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[5], [6], [7], [8]]), 22), + np.array([1.0, 2.0, 3.0, 4.0])) + + dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + self.run_core_tests(lambda: self._build_tensor_slices_dataset(components), + lambda: self._build_tensor_slices_dataset(diff_comp), 4) + self.run_core_tests( + lambda: self._build_tensor_slices_dataset(dict_components), None, 3) + + def _build_sparse_tensor_slice_dataset(self, slices): + indices = np.array( + [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))], + dtype=np.int64) + values = np.array([val for s in slices for val in s], dtype=np.float64) + dense_shape = np.array( + [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64) + sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape) + return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components) + + def testFromSparseTensorSlicesCore(self): + slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] + diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []] + + self.run_core_tests( + lambda: self._build_sparse_tensor_slice_dataset(slices), + lambda: self._build_sparse_tensor_slice_dataset(diff_slices), + 9, + sparse_tensors=True) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..bf25cc60a1c0efc09bed6501fd2d6f4ccb07764b --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -0,0 +1,633 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing serializable datasets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import nest + + +class DatasetSerializationTestBase(test.TestCase): + """Base class for testing serializable datasets.""" + + def tearDown(self): + self._delete_ckpt() + + def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): + """Runs the core tests. + + Args: + ds_fn1: 0-argument function that returns a Dataset. + ds_fn2: 0-argument function that returns a Dataset different from + ds_fn1. If None, verify_restore_in_modified_graph test is not run. + num_outputs: Total number of outputs expected from this Dataset. + sparse_tensors: Whether dataset is built from SparseTensor(s). + + Raises: + AssertionError if any test fails. + """ + self.verify_unused_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_fully_used_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_exhausted_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_init_before_restore( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_multiple_breaks( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_reset_restored_iterator( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + self.verify_restore_in_empty_graph( + ds_fn1, num_outputs, sparse_tensors=sparse_tensors) + if ds_fn2: + self.verify_restore_in_modified_graph( + ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors) + + def verify_unused_iterator(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that saving and restoring an unused iterator works. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, [0], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_fully_used_iterator(self, ds_fn, num_outputs, + sparse_tensors=False): + """Verifies that saving and restoring a fully used iterator works. + + Note that this only checks saving and restoring an iterator from which + `num_outputs` items have been produced but does not check for an + exhausted iterator, i.e., one from which an OutOfRange error has been + returned. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if test fails. + """ + self.verify_run_with_breaks( + ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors) + + def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False): + """Verifies that saving and restoring an exhausted iterator works. + + An exhausted iterator is one which has returned an OutOfRange error. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + self.gen_outputs( + ds_fn, [], + num_outputs, + verify_exhausted=True, + sparse_tensors=sparse_tensors) + actual = self.gen_outputs( + ds_fn, [], + 0, + ckpt_saved=True, + verify_exhausted=True, + sparse_tensors=sparse_tensors) + self.assertEqual(len(actual), 0) + + def verify_init_before_restore(self, + ds_fn, + num_outputs, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that restoring into an already initilized iterator works. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs), + num_outputs, + init_before_restore=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_multiple_breaks(self, + ds_fn, + num_outputs, + num_breaks=10, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to save/restore at multiple break points. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + num_breaks: The number of break points. These are uniformly spread in + [0, num_outputs] both inclusive. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs, num_breaks), + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + def verify_reset_restored_iterator(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to re-initialize a restored iterator. + + This is useful when restoring a training checkpoint during validation. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Collect ground truth containing all outputs. + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Skip some items and save checkpoint. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Restore from checkpoint and then run init_op. + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(num_outputs): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self.match(expected, actual) + + def verify_restore_in_modified_graph(self, + ds_fn1, + ds_fn2, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to restore an iterator in a modified graph. + + Builds an input pipeline using ds_fn1, runs it for `break_point` steps + and saves a checkpoint. Then builds a new graph using ds_fn2, restores + the checkpoint from ds_fn1 and verifies that the restore is successful. + + Args: + ds_fn1: See `run_core_tests`. + ds_fn2: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn1 + # in `expected`. + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + expected = self.gen_outputs( + ds_fn1, [], + num_outputs - break_point, + ckpt_saved=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Generate `break_point` items from ds_fn1 and save checkpoint. + self.gen_outputs( + ds_fn1, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Build graph for ds_fn2 but load checkpoint for ds_fn1. + with ops.Graph().as_default() as g: + _, get_next_op, saver = self._build_graph( + ds_fn2, sparse_tensors=sparse_tensors) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.match(expected, actual) + + def verify_restore_in_empty_graph(self, + ds_fn, + num_outputs, + break_point=None, + sparse_tensors=False, + verify_exhausted=True): + """Attempts to restore an iterator in an empty graph. + + Builds an input pipeline using ds_fn, runs it for `break_point` steps + and saves a checkpoint. Then builds a new empty graph, restores + the checkpoint from ds_fn and verifies that the restore is successful. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn + # in `expected`. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + expected = self.gen_outputs( + ds_fn, [], + num_outputs - break_point, + ckpt_saved=True, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + # Generate `break_point` items from ds_fn and save checkpoint. + self.gen_outputs( + ds_fn, [], + break_point, + sparse_tensors=sparse_tensors, + verify_exhausted=False) + + actual = [] + # Build an empty graph but load checkpoint for ds_fn. + with ops.Graph().as_default() as g: + get_next_op, saver = self._build_empty_graph( + ds_fn, sparse_tensors=sparse_tensors) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + if verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + self.match(expected, actual) + + def verify_error_on_save(self, + ds_fn, + num_outputs, + error, + break_point=None, + sparse_tensors=False): + """Attempts to save a non-saveable iterator. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + error: Declared error when trying to save iterator. + break_point: Break point. Optional. Defaults to num_outputs/2. + sparse_tensors: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + + break_point = num_outputs // 2 if not break_point else break_point + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + with self.assertRaises(error): + self._save(sess, saver) + + def verify_run_with_breaks(self, + ds_fn, + break_points, + num_outputs, + init_before_restore=False, + sparse_tensors=False, + verify_exhausted=True): + """Verifies that ds_fn() produces the same outputs with and without breaks. + + 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it + *without* stopping at break points. + 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it + with stopping at break points. + + Deep matches outputs from 1 and 2. + + Args: + ds_fn: See `gen_outputs`. + break_points: See `gen_outputs`. + num_outputs: See `gen_outputs`. + init_before_restore: See `gen_outputs`. + sparse_tensors: See `run_core_tests`. + verify_exhausted: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + actual = self.gen_outputs( + ds_fn, + break_points, + num_outputs, + init_before_restore=init_before_restore, + sparse_tensors=sparse_tensors, + verify_exhausted=verify_exhausted) + + self.match(expected, actual) + + def gen_outputs(self, + ds_fn, + break_points, + num_outputs, + ckpt_saved=False, + init_before_restore=False, + sparse_tensors=False, + verify_exhausted=True): + """Generates elements from input dataset while stopping at break points. + + Produces `num_outputs` outputs and saves the state of the iterator in the + Saver checkpoint. + + Args: + ds_fn: 0-argument function that returns the dataset. + break_points: A list of integers. For each `break_point` in + `break_points`, we produce outputs till `break_point` number of items + have been produced and then checkpoint the state. The current graph + and session are destroyed and a new graph and session are used to + produce outputs till next checkpoint or till `num_outputs` elements + have been produced. `break_point` must be <= `num_outputs`. + num_outputs: The total number of outputs to produce from the iterator. + ckpt_saved: Whether a checkpoint already exists. If False, we build the + graph from ds_fn. + init_before_restore: Whether init should be called before saver.restore. + This is just so that we can verify that restoring an already initialized + iterator works. + sparse_tensors: Whether dataset is built from SparseTensor(s). + verify_exhausted: Whether to verify that the iterator has been exhausted + after producing `num_outputs` elements. + + Returns: + A list of `num_outputs` items. + """ + outputs = [] + + def get_ops(): + if ckpt_saved: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection( + ds_fn, sparse_tensors=sparse_tensors) + else: + init_op, get_next_op, saver = self._build_graph( + ds_fn, sparse_tensors=sparse_tensors) + return init_op, get_next_op, saver + + for i in range(len(break_points) + 1): + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = get_ops() + with self.test_session(graph=g) as sess: + if ckpt_saved: + if init_before_restore: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + self._restore(saver, sess) + else: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + start = break_points[i - 1] if i > 0 else 0 + end = break_points[i] if i < len(break_points) else num_outputs + num_iters = end - start + for _ in range(num_iters): + outputs.append(sess.run(get_next_op)) + if i == len(break_points) and verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._save(sess, saver) + ckpt_saved = True + + return outputs + + def match(self, expected, actual): + """Matches nested structures. + + Recursively matches shape and values of `expected` and `actual`. + Handles scalars, numpy arrays and other python sequence containers + e.g. list, dict. + + Args: + expected: Nested structure 1. + actual: Nested structure 2. + + Raises: + AssertionError if matching fails. + """ + if isinstance(expected, np.ndarray): + expected = expected.tolist() + if isinstance(actual, np.ndarray): + actual = actual.tolist() + self.assertEqual(type(expected), type(actual)) + + if nest.is_sequence(expected): + self.assertEqual(len(expected), len(actual)) + if isinstance(expected, dict): + for key1, key2 in zip(sorted(expected), sorted(actual)): + self.assertEqual(key1, key2) + self.match(expected[key1], actual[key2]) + else: + for item1, item2 in zip(expected, actual): + self.match(item1, item2) + else: + self.assertEqual(expected, actual) + + def does_not_match(self, expected, actual): + with self.assertRaises(AssertionError): + self.match(expected, actual) + + def gen_break_points(self, num_outputs, num_samples=10): + """Generates `num_samples` breaks points in [0, num_outputs].""" + return np.linspace(0, num_outputs, num_samples, dtype=int) + + def _build_graph(self, ds_fn, sparse_tensors=False): + iterator = ds_fn().make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + self._add_iterator_ops_to_collection(init_op, get_next, sparse_tensors) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _build_empty_graph(self, ds_fn, sparse_tensors=False): + iterator = iterator_ops.Iterator.from_structure( + self._get_output_types(ds_fn), self._get_output_shapes(ds_fn)) + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + if sparse_tensors: + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + else: + get_next = iterator.get_next() + saver = saver_lib.Saver(allow_empty=True) + return get_next, saver + + def _add_iterator_ops_to_collection(self, + init_op, + get_next, + sparse_tensors=False): + ops.add_to_collection("iterator_ops", init_op) + # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections + # do not support tuples we flatten the tensors and restore the shape in + # `_get_iterator_ops_from_collection`. + if sparse_tensors: + ops.add_to_collection("iterator_ops", get_next.indices) + ops.add_to_collection("iterator_ops", get_next.values) + ops.add_to_collection("iterator_ops", get_next.dense_shape) + else: + for el in nest.flatten(get_next): + ops.add_to_collection("iterator_ops", el) + + def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): + all_ops = ops.get_collection("iterator_ops") + if sparse_tensors: + init_op, indices, values, dense_shape = all_ops + return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) + else: + return all_ops[0], nest.pack_sequence_as( + self._get_output_types(ds_fn), all_ops[1:]) + + def _get_output_types(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_types + + def _get_output_shapes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_shapes + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return saver_lib.latest_checkpoint(self.get_temp_dir()) + + def _save(self, sess, saver): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + saver.restore(sess, self._latest_ckpt()) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _delete_ckpt(self): + # Remove all checkpoint files. + prefix = self._ckpt_path() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 00323da3110bb7f32b589f72e4e867f9c71e92ee..5921be2ae89ba1bbbb8d6e3a509cf49c65949544 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -19,9 +19,11 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops @@ -124,6 +126,74 @@ class FilterDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparse(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])), i + + def _filter_fn(_, i): + return math_ops.equal(i % 2, 0) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( + lambda x, i: x).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(5): + actual = sess.run(get_next) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class FilterDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_filter_range_graph(self, div): + return dataset_ops.Dataset.range(100).filter( + lambda x: math_ops.not_equal(math_ops.mod(x, div), 2)) + + def testFilterCore(self): + div = 3 + num_outputs = np.sum([x % 3 is not 2 for x in range(100)]) + self.run_core_tests(lambda: self._build_filter_range_graph(div), + lambda: self._build_filter_range_graph(div * 2), + num_outputs) + + def _build_filter_dict_graph(self): + return dataset_ops.Dataset.range(10).map( + lambda x: {"foo": x * 2, "bar": x ** 2}).filter( + lambda d: math_ops.equal(d["bar"] % 2, 0)).map( + lambda d: d["foo"] + d["bar"]) + + def testFilterDictCore(self): + num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)]) + self.run_core_tests(self._build_filter_dict_graph, None, num_outputs) + + def _build_sparse_filter(self): + + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i + + def _filter_fn(_, i): + return math_ops.equal(i % 2, 0) + + return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( + lambda x, i: x) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index 2a582ae6620ac8276d290c7b995588640e36929c..d4fbaa5cdcdd315aa0524134b48eb0515169722c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -17,16 +17,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools import random import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.client import session +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops +from tensorflow.python.framework import function +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -123,154 +129,101 @@ class FlatMapDatasetTest(test.TestCase): sess.run(get_next) # pylint: enable=g-long-lambda + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) -class InterleaveDatasetTest(test.TestCase): - - def _interleave(self, lists, cycle_length, block_length): - num_open = 0 - - # `all_iterators` acts as a queue of iterators over each element of `lists`. - all_iterators = [iter(l) for l in lists] - - # `open_iterators` are the iterators whose elements are currently being - # interleaved. - open_iterators = [] - for i in range(cycle_length): - if all_iterators: - open_iterators.append(all_iterators.pop(0)) - num_open += 1 - else: - open_iterators.append(None) - - while num_open or all_iterators: - for i in range(cycle_length): - if open_iterators[i] is None: - if all_iterators: - open_iterators[i] = all_iterators.pop(0) - num_open += 1 - else: - continue - for _ in range(block_length): - try: - yield next(open_iterators[i]) - except StopIteration: - open_iterators[i] = None - num_open -= 1 - break - - def testPythonImplementation(self): - input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], - [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] - - # Cycle length 1 acts like `Dataset.flat_map()`. - expected_elements = itertools.chain(*input_lists) - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 1, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1. - expected_elements = [4, 5, 4, 5, 4, 5, 4, - 5, 5, 6, 6, # NOTE(mrry): When we cycle back - # to a list and are already at - # the end of that list, we move - # on to the next element. - 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1 and block length > 1. - expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, - 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 3)): - self.assertEqual(expected, produced) - - # Cycle length > len(input_values). - expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, - 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 7, 2)): - self.assertEqual(expected, produced) - - def testInterleaveDataset(self): - input_values = array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() + def _flat_map_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + .make_initializable_iterator()) init_op = iterator.initializer - next_element = iterator.get_next() + get_next = iterator.get_next() with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. - # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. - # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) + sess.run(get_next) - # Empty input. - sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. - sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) +class FlatMapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testCore(self): + # Complicated way of saying range(start, start+25). + def build_ds(start): + + def map_fn(x): + return dataset_ops.Dataset.range(x, x + 5) + + return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn) + + self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25) + + def testMapThenFlatMap(self): + + def build_ds(): + + def flat_map_fn(_): + + def map_fn(y): + return 10 * math_ops.to_int32(y) + + return dataset_ops.Dataset.range(100).map(map_fn) + + return dataset_ops.Dataset.range(5).flat_map(flat_map_fn) + + self.run_core_tests(build_ds, None, 500) + + def testCaptureDefunInMapFn(self): + + def build_ds(): + + def map_fn(x): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)]) + + return dataset_ops.Dataset.range(100).flat_map(map_fn) + + self.run_core_tests(build_ds, None, 100) + + def testDisallowVariableCapture(self): + + def build_ds(): + test_var = variable_scope.get_variable( + name="test_var", shape=(), use_resource=True) + return dataset_ops.Dataset.range(5).flat_map( + lambda _: dataset_ops.Dataset.from_tensor_slices([test_var])) + + self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError) + + def testDisallowCapturingStatefulOps(self): + + def build_ds(): + + def flat_map_fn(_): + + def map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(map_fn) + + return dataset_ops.Dataset.range(5).flat_map(flat_map_fn) + + self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 0aa9ea88de82b0851b0236d9412039d6573ab291..e66ed3f7aa2a512813ef353d2d0744ae67005884 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -22,18 +22,236 @@ import math import threading import time +import numpy as np from six.moves import zip_longest +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test +class InterleaveDatasetTest(test.TestCase): + + def _interleave(self, lists, cycle_length, block_length): + num_open = 0 + + # `all_iterators` acts as a queue of iterators over each element of `lists`. + all_iterators = [iter(l) for l in lists] + + # `open_iterators` are the iterators whose elements are currently being + # interleaved. + open_iterators = [] + for i in range(cycle_length): + if all_iterators: + open_iterators.append(all_iterators.pop(0)) + num_open += 1 + else: + open_iterators.append(None) + + while num_open or all_iterators: + for i in range(cycle_length): + if open_iterators[i] is None: + if all_iterators: + open_iterators[i] = all_iterators.pop(0) + num_open += 1 + else: + continue + for _ in range(block_length): + try: + yield next(open_iterators[i]) + except StopIteration: + open_iterators[i] = None + num_open -= 1 + break + + def testPythonImplementation(self): + input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], + [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] + + # Cycle length 1 acts like `Dataset.flat_map()`. + expected_elements = itertools.chain(*input_lists) + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 1, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1. + expected_elements = [4, 5, 4, 5, 4, 5, 4, + 5, 5, 6, 6, # NOTE(mrry): When we cycle back + # to a list and are already at + # the end of that list, we move + # on to the next element. + 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 2, 1)): + self.assertEqual(expected, produced) + + # Cycle length > 1 and block length > 1. + expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, + 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 2, 3)): + self.assertEqual(expected, produced) + + # Cycle length > len(input_values). + expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, + 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] + for expected, produced in zip( + expected_elements, self._interleave(input_lists, 7, 2)): + self.assertEqual(expected, produced) + + def testInterleaveDataset(self): + input_values = array_ops.placeholder(dtypes.int64, shape=[None]) + cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) + block_length = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_count = 2 + + dataset = ( + dataset_ops.Dataset.from_tensor_slices(input_values) + .repeat(repeat_count) + .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + next_element = iterator.get_next() + + with self.test_session() as sess: + # Cycle length 1 acts like `Dataset.flat_map()`. + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 1, block_length: 3}) + + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): + self.assertEqual(expected_element, sess.run(next_element)) + + # Cycle length > 1. + # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, + # 6, 5, 6, 5, 6, 5, 6, 5] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 2, block_length: 1}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Cycle length > 1 and block length > 1. + # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, + # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 2, block_length: 3}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Cycle length > len(input_values) * repeat_count. + # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, + # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] + sess.run(init_op, feed_dict={input_values: [4, 5, 6], + cycle_length: 7, block_length: 2}) + for expected_element in self._interleave( + [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Empty input. + sess.run(init_op, feed_dict={input_values: [], + cycle_length: 2, block_length: 3}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Non-empty input leading to empty output. + sess.run(init_op, feed_dict={input_values: [0, 0, 0], + cycle_length: 2, block_length: 3}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Mixture of non-empty and empty interleaved datasets. + sess.run(init_op, feed_dict={input_values: [4, 0, 6], + cycle_length: 2, block_length: 3}) + for expected_element in self._interleave( + [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): + self.assertEqual(expected_element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testSparse(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + iterator = ( + dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + +class InterleaveDatasetSeriazationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, input_values, cycle_length, block_length): + repeat_count = 2 + return dataset_ops.Dataset.from_tensor_slices(input_values).repeat( + repeat_count).interleave( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), + cycle_length, block_length) + + def testSerializationCore(self): + input_values = np.array([4, 5, 6], dtype=np.int64) + num_outputs = np.sum(input_values) * 2 + # cycle_length > 1, block_length > 1 + cycle_length = 2 + block_length = 3 + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + lambda: self._build_iterator_graph( + input_values, cycle_length * 2, block_length * 1), + num_outputs) + # cycle_length = 1 + cycle_length = 1 + block_length = 3 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # block_length = 1 + cycle_length = 2 + block_length = 1 + self.run_core_tests( + lambda: self._build_iterator_graph( + input_values, cycle_length, block_length), + None, num_outputs) + # pylint: enable=g-long-lambda + + class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): @@ -547,5 +765,31 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testTooManyReadersSloppy(self): self._testTooManyReaders(sloppy=True) + def testSparse(self): + def _map_fn(i): + return sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + dataset = dataset_ops.Dataset.range(10).map(_map_fn) + iterator = dataset.apply( + interleave_ops.parallel_interleave( + _interleave_fn, cycle_length=1)).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + for j in range(2): + expected = [i, 0] if j % 2 == 0 else [0, -i] + self.assertAllEqual(expected, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 8a1d99499be702d91f87f65f443261b47ce5c5cd..e9a07da84a8c80c09ebd4dab0b1d69febe1c9790 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -20,15 +20,18 @@ from collections import namedtuple import os import threading -from collections import namedtuple import numpy as np -from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import error_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops @@ -37,6 +40,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test @@ -616,6 +620,182 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparse(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + iterator = (dataset_ops.Dataset.range(10) + .map(_sparse) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + actual = sess.run(get_next) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, _sparse(i)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSparseChain(self): + + def _sparse(i): + return sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0]]), + values=(i * np.array([1])), + dense_shape=np.array([1, 1])) + + def _check(i): + self.assertTrue(sparse_tensor.is_sparse(i)) + return sparse_ops.sparse_concat(0, [i, i]) + + iterator = ( + dataset_ops.Dataset.range(10).map(_sparse).map(_check) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(10): + actual = sess.run(get_next) + self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) + self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval()) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testCaptureResourceInMapFn(self): + + def _build_ds(iterator): + + def _map_fn(x): + get_next = iterator.get_next() + return x * get_next + + return dataset_ops.Dataset.range(10).map(_map_fn) + + def _build_graph(): + captured_iterator = dataset_ops.Dataset.range( + 10).make_initializable_iterator() + ds = _build_ds(captured_iterator) + iterator = ds.make_initializable_iterator() + init_op = iterator.initializer + return captured_iterator.initializer, init_op + + with ops.Graph().as_default() as g: + captured_init_op, init_op = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(captured_init_op) + with self.assertRaises(errors.UnimplementedError): + # CapturedFunction does not support capturing IteratorResource. + sess.run(init_op) + + +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs + + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) + + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) + + def testSaveStatefulFunction(self): + + def _build_ds(): + + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) + + return dataset_ops.Dataset.range(100).map(_map_fn) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureVariableInMapFn(self): + + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) + + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + + def testCaptureDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + def testBuildDefunInMapFn(self): + num_outputs = 100 + + def _build_ds(): + + @function.Defun(dtypes.int64) + def defun_fn(x): + + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) + + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) + + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + + self.run_core_tests(_build_ds, None, num_outputs) + + +class IgnoreErrorsSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors()) + + def testIgnoreErrorsCore(self): + components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) + diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32) + num_outputs = 4 + self.run_core_tests(lambda: self._build_ds(components), + lambda: self._build_ds(diff_components), num_outputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_impl.py b/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py similarity index 51% rename from tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_impl.py rename to tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py index a640dfe7dfbcce96261589c7fc49107deaefdd54..3d120a3071ef730f21221e3291d8c84385b51aa3 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_impl.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetch_dataset_op_test.py @@ -12,37 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Sigmoid bijector.""" - +"""Tests for the experimental input pipeline ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Sigmoid", -] - +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.platform import test -class Sigmoid(bijector.Bijector): - """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" - def __init__(self, validate_args=False, name="sigmoid"): - super(Sigmoid, self).__init__( - event_ndims=0, validate_args=validate_args, name=name) +class PrefetchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - def _forward(self, x): - return math_ops.sigmoid(x) + def build_dataset(self, seed): + return dataset_ops.Dataset.range(100).prefetch(10).shuffle( + buffer_size=10, seed=seed, reshuffle_each_iteration=False) - def _inverse(self, y): - return math_ops.log(y) - math_ops.log1p(-y) + def testCore(self): + num_outputs = 100 + self.run_core_tests(lambda: self.build_dataset(10), + lambda: self.build_dataset(20), num_outputs) - def _inverse_log_det_jacobian(self, y): - return -math_ops.log(y) - math_ops.log1p(-y) - def _forward_log_det_jacobian(self, x): - return -nn_ops.softplus(-x) - nn_ops.softplus(x) +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 539c6f215536f50a0b56f173a9240542faa2e643..dc3e38db59301bf1819999f479171af35930e9d2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test @@ -85,6 +86,9 @@ class StagingAreaOpsTest(test.TestCase): self._event.wait() elem = sess.run(prefetch_op) self.assertEqual(elem, [5.0]) + sess.run( + resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True)) def testSameDeviceCPU(self): self._prefetch_fn_helper("same_device_cpu", diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index f59ac760dc83a504e563f055b91f1002cb0c80fc..8e6ad061a11752ab7b1ffc13c90b4fa52f67d6aa 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import os +from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops @@ -194,6 +195,27 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testCounter(self): + """Test dataset construction using `count`.""" + iterator = (counter.Counter(start=3, step=4) + .make_one_shot_iterator()) + get_next = iterator.get_next() + self.assertEqual([], get_next.shape.as_list()) + self.assertEqual(dtypes.int64, get_next.dtype) + + negative_iterator = (counter.Counter(start=0, step=-1) + .make_one_shot_iterator()) + negative_get_next = negative_iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(3, sess.run(get_next)) + self.assertEqual(3 + 4, sess.run(get_next)) + self.assertEqual(3 + 2 * 4, sess.run(get_next)) + + self.assertEqual(0, sess.run(negative_get_next)) + self.assertEqual(-1, sess.run(negative_get_next)) + self.assertEqual(-2, sess.run(negative_get_next)) + def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 3ae8f71d77fa6ecf08e42bedac702b8f75eec309..1c42a3d855bc16c21e385d7108c3106884ae4f5e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,7 +21,7 @@ import gzip import os import zlib -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 @@ -30,18 +30,14 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat -class TextLineDatasetTest(test.TestCase): +class TextLineDatasetTestBase(test.TestCase): def _lineText(self, f, l): return compat.as_bytes("%d: %d" % (f, l)) @@ -79,6 +75,9 @@ class TextLineDatasetTest(test.TestCase): return filenames + +class TextLineDatasetTest(TextLineDatasetTestBase): + def _testTextLineDataset(self, compression_type=None): test_filenames = self._createFiles( 2, 5, crlf=True, compression_type=compression_type) @@ -165,282 +164,37 @@ class TextLineDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) - def _ckpt_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) - - def _save(self, saver, sess): - saver.save(sess, self._ckpt_path()) - - def _restore(self, saver, sess): - saver.restore(sess, self._latest_ckpt()) - def _import_meta_graph(self): - meta_file_path = self._ckpt_path() + ".meta" - return saver_lib.import_meta_graph(meta_file_path) +class TextLineDatasetSerializationTest( + TextLineDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): - def _build_graph(self, - test_filenames, - compression_type=None, - build_saveable=True): - ds = readers.TextLineDataset( + def _build_iterator_graph(self, test_filenames, compression_type=None): + return readers.TextLineDataset( test_filenames, compression_type=compression_type, buffer_size=10) - iterator = ds.make_initializable_iterator() - if build_saveable: - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - def _testReadWithBreaks(self, breaks, num_files=5, lines_per_file=5): - """Tests reading from input pipeline with regular breaks. - - At each break point the iterator state gets saved using Saver and reloaded - in a new Graph and session. - - Args: - breaks: List of counts of records after reading which iterator state is - checkpointed. Must to in non-decreasing order. - num_files: Total number of files. - lines_per_file: Total number of lines per file. - """ + + def testTextLineCore(self): compression_types = [None, "GZIP", "ZLIB"] + num_files = 5 + lines_per_file = 5 + num_outputs = num_files * lines_per_file for compression_type in compression_types: test_filenames = self._createFiles( num_files, lines_per_file, crlf=True, compression_type=compression_type) + # pylint: disable=cell-var-from-loop + self.run_core_tests( + lambda: self._build_iterator_graph(test_filenames, compression_type), + lambda: self._build_iterator_graph(test_filenames), num_outputs) + # pylint: enable=cell-var-from-loop - # Collect ground truth. - total_records = num_files * lines_per_file - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type=compression_type) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(total_records): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Simulate run with breaks. - actual_records = [] - next_record_index = 0 - load_from_ckpt = False - breaks.append(total_records) - for break_index in breaks: - with ops.Graph().as_default() as g: - if not load_from_ckpt: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type=compression_type) - else: - saver = self._import_meta_graph() - init_op, get_next = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - if not load_from_ckpt: - sess.run(init_op) - else: - self._restore(saver, sess) - while next_record_index != break_index: - actual_records.append(sess.run(get_next)) - next_record_index += 1 - if break_index == total_records: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self._save(saver, sess) - load_from_ckpt = True - self.assertEqual(actual_records, expected_records) - - def testSaveAtFileBoundary(self): - self._testReadWithBreaks([10]) - - def testSaveWithinFile(self): - self._testReadWithBreaks([12]) - - def testSaveUnusedIterator(self): - self._testReadWithBreaks([0]) - - def testSaveRestoreIdempotence(self): - # Attempt to save an iterator immediately after it has been - # restored. - self._testReadWithBreaks([0, 0]) - self._testReadWithBreaks([10, 10]) - self._testReadWithBreaks([12, 12]) - - def testMultipleBreaks(self): - self._testReadWithBreaks([0, 4, 20]) - - def testRestoreExhaustedIterator(self): - num_files = 2 - lines_per_file = 5 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_files * lines_per_file): - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self._save(saver, sess) - - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - saver = self._import_meta_graph() - self._restore(saver, sess) - _, get_next = ops.get_collection("iterator_ops") - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testInitThenRestore(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - sess.run(get_next) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - saver = self._import_meta_graph() - init_op, get_next = ops.get_collection("iterator_ops") - sess.run(init_op) - self._restore(saver, sess) - for _ in range(total_records - break_record): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def testRestoreInModifiedGraph(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - sess.run(get_next) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type="GZIP") - self._restore(saver, sess) - for _ in range(total_records - break_record): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def testRestoreInModifiedGraphThenInit(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - expected_records.append(sess.run(get_next)) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test that calling the init_op overrides the restored iterator. The - # iterator for the old graph was build to read uncompressed files and - # would fail when trying to read the new files. - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - test_filenames = self._createFiles( - num_files, lines_per_file, crlf=True, compression_type="GZIP") - init_op, get_next, saver = self._build_graph( - test_filenames, compression_type="GZIP") - self._restore(saver, sess) - sess.run(init_op) - for _ in range(total_records): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - def testDoNotRestoreIterator(self): - num_files = 5 - lines_per_file = 5 - total_records = num_files * lines_per_file - break_record = 8 - test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) - - expected_records = [] - with ops.Graph().as_default() as g: - init_op, get_next, saver = self._build_graph(test_filenames) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_record): - expected_records.append(sess.run(get_next)) - self._save(saver, sess) - for _ in range(total_records - break_record): - expected_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - actual_records = [] - with ops.Graph().as_default() as g: - with self.test_session(graph=g) as sess: - init_op, get_next, saver = self._build_graph( - test_filenames, build_saveable=False) - self._restore(saver, sess) - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next) - sess.run(init_op) - for _ in range(total_records): - actual_records.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(actual_records, expected_records) - - -class FixedLengthRecordReaderTest(test.TestCase): +class FixedLengthRecordReaderTestBase(test.TestCase): def setUp(self): - super(FixedLengthRecordReaderTest, self).setUp() + super(FixedLengthRecordReaderTestBase, self).setUp() self._num_files = 2 self._num_records = 7 self._header_bytes = 5 @@ -462,6 +216,9 @@ class FixedLengthRecordReaderTest(test.TestCase): f.write(b"F" * self._footer_bytes) return filenames + +class FixedLengthRecordReaderTest(FixedLengthRecordReaderTestBase): + def testFixedLengthRecordDataset(self): test_filenames = self._createFiles() filenames = array_ops.placeholder(dtypes.string, shape=[None]) @@ -547,304 +304,29 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) - def _iterator_checkpoint_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_path(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def _build_iterator_graph(self, num_epochs): + +class FixedLengthRecordDatasetSerializationTest( + FixedLengthRecordReaderTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, num_epochs, compression_type=None): filenames = self._createFiles() - dataset = (readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, self._footer_bytes) - .repeat(num_epochs)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next_op = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next_op, save_op, restore_op - - def _restore_iterator(self): - output_types = dtypes.string - output_shapes = tensor_shape.scalar() - iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) - get_next = iterator.get_next() - restore_op = self._restore_op(iterator._iterator_resource) - return restore_op, get_next - - def testSaveRestore(self): - num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testInitThenRestore(self): - # Note: Calling init_op before restore_op is redundant. This test just makes - # sure we do not fail if restore is called on an already initialized - # iterator resource. - num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreInModifiedGraph(self): - num_epochs = 10 - num_epochs_1 = 20 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs_1) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreWithoutBuildingDatasetGraph(self): - num_epochs = 10 - epoch_break = 5 - file_break = self._num_files // 2 - record_break = self._num_records // 2 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch == epoch_break and f == file_break and - r == record_break): - sess.run(save_op) - break - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - else: - continue - break - else: - continue - break - else: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - restore_op, get_next_op = self._restore_iterator() - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for epoch in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - if (epoch < epoch_break or - (epoch == epoch_break and f < file_break) or - (epoch == epoch_break and f == file_break and - r < record_break)): - continue - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreUnusedIterator(self): - num_epochs = 10 - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - # Save unused iterator. - sess.run(save_op) - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for _ in range(num_epochs * self._num_files * self._num_records): - sess.run(get_next_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testRestoreExhaustedIterator(self): - num_epochs = 10 - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(num_epochs): - for f in range(self._num_files): - for r in range(self._num_records): - self.assertEqual(self._record(f, r), sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( - num_epochs=num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - -class TFRecordDatasetTest(test.TestCase): + return readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, + self._footer_bytes).repeat(num_epochs) + + def testFixedLengthRecordCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + +class TFRecordDatasetTestBase(test.TestCase): def setUp(self): - super(TFRecordDatasetTest, self).setUp() + super(TFRecordDatasetTestBase, self).setUp() self._num_files = 2 self._num_records = 7 @@ -880,6 +362,9 @@ class TFRecordDatasetTest(test.TestCase): writer.close() return filenames + +class TFRecordDatasetTest(TFRecordDatasetTestBase): + def testReadOneEpoch(self): with self.test_session() as sess: # Basic test: read from file 0. @@ -1001,6 +486,74 @@ class TFRecordDatasetTest(test.TestCase): sess.run(iterator.get_next()) +class TFRecordDatasetSerializationTest( + TFRecordDatasetTestBase, + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_iterator_graph(self, + num_epochs, + batch_size=1, + compression_type=None, + buffer_size=None): + filenames = self._createFiles() + if compression_type is "ZLIB": + zlib_files = [] + for i, fn in enumerate(filenames): + with open(fn, "rb") as f: + cdata = zlib.compress(f.read()) + zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) + with open(zfn, "wb") as f: + f.write(cdata) + zlib_files.append(zfn) + filenames = zlib_files + + elif compression_type is "GZIP": + gzip_files = [] + for i, fn in enumerate(self.test_filenames): + with open(fn, "rb") as f: + gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) + with gzip.GzipFile(gzfn, "wb") as gzf: + gzf.write(f.read()) + gzip_files.append(gzfn) + filenames = gzip_files + + return readers.TFRecordDataset( + filenames, compression_type, + buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) + + def testTFRecordWithoutBufferCore(self): + num_epochs = 5 + batch_size = num_epochs + num_outputs = num_epochs * self._num_files * self._num_records // batch_size + # pylint: disable=g-long-lambda + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, batch_size, + buffer_size=0), + lambda: self._build_iterator_graph(num_epochs * 2, batch_size), + num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None, + num_outputs * batch_size) + # pylint: enable=g-long-lambda + + def testTFRecordWithBufferCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests(lambda: self._build_iterator_graph(num_epochs), + lambda: self._build_iterator_graph(num_epochs * 2), + num_outputs) + + def testTFRecordWithCompressionCore(self): + num_epochs = 5 + num_outputs = num_epochs * self._num_files * self._num_records + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + self.run_core_tests( + lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"), + lambda: self._build_iterator_graph(num_epochs * 2), num_outputs) + + class ReadBatchFeaturesTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 91615e9f6205cc95ff531b98683ff485964f714e..1a26da82e533ec01106ea10525c1cd96627c34fb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -207,5 +208,82 @@ class SequenceDatasetTest(test.TestCase): sess.run(get_next) +class SequenceDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_skip_dataset(self, count): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).skip(count) + + def testSkipFewerThanInputs(self): + count = 4 + num_outputs = 10 - count + self.run_core_tests(lambda: self._build_skip_dataset(count), + lambda: self._build_skip_dataset(count + 2), + num_outputs) + + def testSkipVarious(self): + # Skip more than inputs + self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0) + # Skip exactly the input size + self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0) + self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0) + # Skip nothing + self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10) + + def _build_take_dataset(self, count): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).take(count) + + def testTakeFewerThanInputs(self): + count = 4 + self.run_core_tests( + lambda: self._build_take_dataset(count), + lambda: self._build_take_dataset(count + 2), + count, + ) + + def testTakeVarious(self): + # Take more than inputs + self.run_core_tests(lambda: self._build_take_dataset(20), None, 10) + # Take exactly the input size + self.run_core_tests(lambda: self._build_take_dataset(10), None, 10) + # Take all + self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10) + # Take nothing + self.run_core_tests(lambda: self._build_take_dataset(0), None, 0) + + def _build_repeat_dataset(self, count, take_count=3): + components = (np.arange(10),) + return dataset_ops.Dataset.from_tensor_slices(components).take( + take_count).repeat(count) + + def testFiniteRepeat(self): + count = 10 + self.run_core_tests(lambda: self._build_repeat_dataset(count), + lambda: self._build_repeat_dataset(count + 2), + 3 * count) + + def testEmptyRepeat(self): + self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0) + + def testInfiniteRepeat(self): + self.verify_unused_iterator( + lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False) + self.verify_init_before_restore( + lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False) + self.verify_restore_in_modified_graph( + lambda: self._build_repeat_dataset(-1), + lambda: self._build_repeat_dataset(2), + 20, + verify_exhausted=False) + # Test repeat empty dataset + self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index e9ebaf4f21534fb43218d9579127b4aeb1dbd85e..ba1be0690ff3d72df9fe40980c0f5d53b33e41c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -18,16 +18,24 @@ from __future__ import division from __future__ import print_function import collections +import os import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class ShuffleDatasetTest(test.TestCase): @@ -42,8 +50,9 @@ class ShuffleDatasetTest(test.TestCase): buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .repeat(count_placeholder)) + repeat_dataset = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components) + .repeat(count_placeholder)) shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder, seed_placeholder) @@ -134,8 +143,9 @@ class ShuffleDatasetTest(test.TestCase): def testDefaultArguments(self): components = [0, 1, 2, 3, 4] - iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) - .repeat().make_one_shot_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) + .repeat().make_one_shot_iterator()) get_next = iterator.get_next() @@ -148,6 +158,401 @@ class ShuffleDatasetTest(test.TestCase): for i in range(5): self.assertEqual(10, counts[i]) + def testSeedNoneSeed2NonNone(self): + with self.assertRaises(ValueError): + dataset_ops.ShuffleDataset(dataset_ops.Dataset.range(5), + buffer_size=1, + seed=None, + seed2=10) + + +class ShuffleDatasetSerializationTest(test.TestCase): + + def tearDown(self): + # Remove all checkpoint files. + prefix = self._ckpt_path() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) + + def _build_graph(self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + build_saveable=True): + iterator = dataset_ops.Dataset.range(range_limit).shuffle( + buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration).repeat( + num_repeats).make_initializable_iterator() + if build_saveable: + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return saver_lib.latest_checkpoint(self.get_temp_dir()) + + def _save(self, sess, saver): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + saver.restore(sess, self._latest_ckpt()) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _testReadWithBreaks(self, break_points, init_before_restore=False): + seed = 55 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + expected = [] + actual = [] + # Generate the ground truth. + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, _ = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_outputs): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Run and checkpoint after first break_point. + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_points[0]): + actual.append(sess.run(get_next_op)) + self._save(sess, saver) + + # Load from checkpoint and continue running while stopping at each + # subsequent checkpoint. + for i in range(len(break_points)): + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = ops.get_collection("iterator_ops") + with self.test_session(graph=g) as sess: + if init_before_restore: + sess.run(init_op) + self._restore(saver, sess) + start = break_points[i] + end = break_points[ + i + 1] if i < len(break_points) - 1 else num_outputs + for _ in range(end - start): + actual.append(sess.run(get_next_op)) + self._save(sess, saver) + if end == num_outputs: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self.assertEqual(expected, actual) + + def testSaveRestore(self): + self._testReadWithBreaks([8]) # rng buffer_size: 0 + self._testReadWithBreaks([13]) # rng buffer_size: 1 + self._testReadWithBreaks([18]) # rng buffer_size: 2 + self._testReadWithBreaks([23]) # rng buffer_size: 3 + + def testSaveUnusedIterator(self): + self._testReadWithBreaks([0]) + + def testSaveFullyUsedIterator(self): + self._testReadWithBreaks([50]) + + def testMultipleBreaks(self): + self._testReadWithBreaks([0, 5, 9, 15, 25, 32]) + + def testIdempotence(self): + # Attempt to save iterator immediately after restoring. + self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32]) + + def testInitThenRestore(self): + self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True) + + def testRestoreExhaustedIterator(self): + seed = 55 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_outputs): + sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._save(sess, saver) + + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = ops.get_collection("iterator_ops") + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testResetRestoredIterator(self): + seed = 55 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_outputs // 2): + sess.run(get_next_op) + self._save(sess, saver) + + outputs = [] + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = ops.get_collection("iterator_ops") + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + sess.run(init_op) + for _ in range(num_outputs): + outputs.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + expected_outputs_sorted = sorted( + np.array([range(range_limit) + for _ in range(num_repeats)]).flatten()) + self.assertEqual(expected_outputs_sorted, sorted(outputs)) + + def testRestoreInModifiedGraph(self): + seed = 55 + break_point = 25 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + expected = [] + actual_without_restore = [] + actual = [] + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + expected.append(sess.run(get_next_op)) + actual.extend(expected) + self._save(sess, saver) + for _ in range(num_outputs - break_point): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + g.seed = 20 # Different seed than previous graph for shuffle rngs. + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_outputs): + actual_without_restore.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + g.seed = 20 # Different seed than previous graph for shuffle rngs. + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Since the modified graph has a different random seed it produces a + # different order of examples. + self.assertNotEqual(expected, actual_without_restore) + self.assertEqual(sorted(expected), sorted(actual_without_restore)) + self.assertEqual(expected, actual) + + def testDoNotBuildSaveable(self): + seed = 55 + break_point = 25 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + actual = [] + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + + with ops.Graph().as_default() as g: + g.seed = 20 # Different seed than previous graph for shuffle rngs. + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration, + build_saveable=False) + with self.test_session(graph=g) as sess: + # Since the SaveableObject was not added to Saver's list + # of saveables, iterator state is not restored by saver.restore(). + self._restore(saver, sess) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(get_next_op) + sess.run(init_op) + for _ in range(num_outputs): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + expected_outputs_sorted = sorted( + np.array([range(range_limit) for _ in range(num_repeats)]).flatten()) + self.assertEqual(expected_outputs_sorted, sorted(actual)) + + +class ShuffleAndRepeatTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed, count=5): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + + def testCorrectOutput(self): + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertSequenceEqual( + sorted(output), sorted( + np.array([range(20) for _ in range(5)]).flatten())) + for i in range(5): + self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) + + def testReshuffling(self): + # Check that the output orders of different epochs are indeed different. + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + for i in range(4): + epoch1 = output[i * 20:(i + 1) * 20] + epoch2 = output[(i + 1) * 20:(i + 2) * 20] + self.assertNotEqual(epoch1, epoch2) + + def testSameOrderForSameSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertEqual(output1, output2) + + def testDifferentOrderForDifferentSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountNone(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountMinusOne(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testInfiniteOutputs(self): + # Asserting that the iterator is exhausted after producing 100 items should + # fail. + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8f24d6b2f612cff662aa8a36085bc69a9ea1a290 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import stats_ops +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class StatsDatasetTest(test.TestCase): + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def testBytesProduced(self): + dataset = dataset_ops.Dataset.range(100).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + expected_sum = 0.0 + for i in range(100): + self.assertAllEqual( + np.array([i] * i, dtype=np.int64), sess.run(next_element)) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1)) + expected_sum += i * 8.0 + self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0) + self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum) + + def testLatencyStats(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + + def testReinitialize(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(stats_aggregator_subscriber) + for j in range(5): + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float((j * 100) + i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", (j + 1) * 100.0) + + def testNoAggregatorRegistered(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testMultipleTags(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.latency_stats("record_latency_2")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(i + 1)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency_2", float(i + 1)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency_2", 100.0) + + def testRepeatedTags(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscriber = stats_aggregator.subscribe(iterator) + next_element = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator.initializer, stats_aggregator_subscriber]) + for i in range(100): + self.assertEqual(i, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(2 * (i + 1))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + + def testMultipleIteratorsSameAggregator(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator_0 = dataset.make_initializable_iterator() + iterator_1 = dataset.make_initializable_iterator() + stats_aggregator = stats_ops.StatsAggregator() + stats_aggregator_subscribers = [stats_aggregator.subscribe(iterator_0), + stats_aggregator.subscribe(iterator_1)] + next_element = iterator_0.get_next() + iterator_1.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run([iterator_0.initializer, iterator_1.initializer, + stats_aggregator_subscribers]) + for i in range(100): + self.assertEqual(i * 2, sess.run(next_element)) + self._assertSummaryHasCount( + sess.run(summary_t), "record_latency", float(2 * (i + 1))) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0) + + def testMultipleStatsAggregatorsSameIteratorFail(self): + dataset = dataset_ops.Dataset.range(100).apply( + stats_ops.latency_stats("record_latency")) + iterator = dataset.make_initializable_iterator() + stats_aggregator_0 = stats_ops.StatsAggregator() + stats_aggregator_1 = stats_ops.StatsAggregator() + + with self.test_session() as sess: + sess.run(stats_aggregator_0.subscribe(iterator)) + # TODO(mrry): Consider making this allowable (and also allowing + # aggregators to unsubscribe). + with self.assertRaises(errors.FailedPreconditionError): + sess.run(stats_aggregator_1.subscribe(iterator)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py index b0e72183019e4d53756542e2a2ef071111120dcd..5d34b0024c472d0393544ff3dad8acea7964345f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -110,5 +111,31 @@ class ZipDatasetTest(test.TestCase): sess.run(get_next) +class ZipDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, arr): + components = [ + np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 22), + np.array(arr) + ] + datasets = [ + dataset_ops.Dataset.from_tensor_slices(component) + for component in components + ] + return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2]))) + + def testCore(self): + # Equal length components + arr = [37.0, 38.0, 39.0, 40.0] + num_outputs = len(arr) + self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs) + # Variable length components + diff_size_arr = [1.0, 2.0] + self.run_core_tests(lambda: self._build_dataset(diff_size_arr), + lambda: self._build_dataset(arr), 2) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 1b81cf5be9190ffab646192fb9a72fd3da7deee1..1f35ee056b7f897ce5e7488b205ecf5a05ef0268 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -14,11 +14,13 @@ load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") py_library( name = "dataset_ops", srcs = [ + "counter.py", "dataset_ops.py", ], srcs_version = "PY2AND3", deps = [ ":transformation_ops", + "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", @@ -38,6 +40,25 @@ py_library( ], ) +py_library( + name = "random_ops", + srcs = [ + "random_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "readers", srcs = [ @@ -60,6 +81,19 @@ py_library( ], ) +py_library( + name = "shuffle_ops", + srcs = [ + "shuffle_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":random_ops", + ":transformation_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_library( name = "transformation_ops", srcs = [ @@ -70,6 +104,7 @@ py_library( "interleave_ops.py", "resampling.py", "scan_ops.py", + "stats_ops.py", ], srcs_version = "PY2AND3", deps = [ @@ -84,8 +119,10 @@ py_library( "//tensorflow/python:random_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", ], ) @@ -117,14 +154,7 @@ tf_custom_op_py_library( deps = [ ":prefetching_ops", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index abc9212a87550745490b974d25a929a66287f785..63782d229e1535892686f202ca1f0833dee6ed80 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -103,6 +104,48 @@ def unbatch(): return _apply_fn +def filter_irregular_batches(batch_size): + """Transformation that filters out batches that are not of size batch_size.""" + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + tensor_batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + + flattened = _RestructuredDataset( + dataset, + tuple(nest.flatten(dataset.output_types)), + output_classes=tuple(nest.flatten(dataset.output_classes))) + + def _predicate(*xs): + """Return `True` if this element is a full batch.""" + # Extract the dynamic batch size from the first component of the flattened + # batched element. + first_component = xs[0] + first_component_batch_size = array_ops.shape( + first_component, out_type=dtypes.int64)[0] + + return math_ops.equal(first_component_batch_size, tensor_batch_size) + + filtered = flattened.filter(_predicate) + + maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) + + def _set_first_dimension(shape): + return shape.merge_with( + tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) + + known_shapes = nest.map_structure(_set_first_dimension, + dataset.output_shapes) + return _RestructuredDataset( + filtered, + dataset.output_types, + known_shapes, + output_classes=dataset.output_classes) + + return _apply_fn + + def batch_and_drop_remainder(batch_size): """A batching transformation that omits the final small batch (if present). @@ -135,34 +178,43 @@ def batch_and_drop_remainder(batch_size): def _apply_fn(dataset): """Function from `Dataset` to `Dataset` that applies the transformation.""" - tensor_batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") + batched = dataset.batch(batch_size) + return filter_irregular_batches(batch_size)(batched) + + return _apply_fn - batched = dataset.batch(tensor_batch_size) - flattened = _RestructuredDataset(batched, - tuple(nest.flatten(batched.output_types))) - def _predicate(*xs): - """Return `True` if this element is a full batch.""" - # Extract the dynamic batch size from the first component of the flattened - # batched element. - first_component = xs[0] - first_component_batch_size = array_ops.shape( - first_component, out_type=dtypes.int64)[0] +def padded_batch_and_drop_remainder(batch_size, + padded_shapes, + padding_values=None): + """A batching and padding transformation that omits the final small batch. - return math_ops.equal(first_component_batch_size, tensor_batch_size) + Like @{tf.data.Dataset.padded_batch}, this transformation combines + consecutive elements of this dataset into batches. However, if the batch + size does not evenly divide the input dataset size, this transformation will + drop the final smaller element. - filtered = flattened.filter(_predicate) + See `@{tf.contrib.data.batch_and_drop_remainder}` for more details. - maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) + Args: + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + padded_shapes: A nested structure of `tf.TensorShape` or + `tf.int64` vector tensor-like objects. See + @{tf.data.Dataset.padded_batch} for details. + padding_values: (Optional.) A nested structure of scalar-shaped + `tf.Tensor`. See @{tf.data.Dataset.padded_batch} for details. - def _set_first_dimension(shape): - return shape.merge_with( - tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply} + """ - known_shapes = nest.map_structure(_set_first_dimension, - batched.output_shapes) - return _RestructuredDataset(filtered, batched.output_types, known_shapes) + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + batched = dataset.padded_batch( + batch_size, padded_shapes=padded_shapes, padding_values=padding_values) + return filter_irregular_batches(batch_size)(batched) return _apply_fn @@ -191,6 +243,10 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): output_shapes=self.output_shapes, output_types=self.output_types) + @property + def output_classes(self): + return (ops.Tensor, ops.Tensor, ops.Tensor) + @property def output_shapes(self): num_elements = tensor_shape.Dimension(None) @@ -206,7 +262,11 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): class _RestructuredDataset(dataset_ops.Dataset): """An internal helper for changing the structure and shape of a dataset.""" - def __init__(self, dataset, output_types, output_shapes=None): + def __init__(self, + dataset, + output_types, + output_shapes=None, + output_classes=None): """Creates a new dataset with the given output types and shapes. The given `dataset` must have a structure that is convertible: @@ -222,6 +282,8 @@ class _RestructuredDataset(dataset_ops.Dataset): output_types: A nested structure of `tf.DType` objects. output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. If omitted, the shapes will be inherited from `dataset`. + output_classes: (Optional.) A nested structure of class types. + If omitted, the class types will be inherited from `dataset`. Raises: ValueError: If either `output_types` or `output_shapes` is not compatible @@ -261,10 +323,21 @@ class _RestructuredDataset(dataset_ops.Dataset): output_shapes)) self._output_shapes = nest.map_structure_up_to( output_types, tensor_shape.as_shape, output_shapes) + if output_classes is None: + # Inherit class types from the original `dataset`. + self._output_classes = nest.pack_sequence_as(output_types, + nest.flatten( + dataset.output_classes)) + else: + self._output_classes = output_classes def _as_variant_tensor(self): return self._dataset._as_variant_tensor() # pylint: disable=protected-access + @property + def output_classes(self): + return self._output_classes + @property def output_types(self): return self._output_types @@ -280,7 +353,6 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): """See `Dataset.map()` for details.""" super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) - self._batch_size = ops.convert_to_tensor( batch_size, dtype=dtypes.int64, name="batch_size") self._num_parallel_batches = ops.convert_to_tensor( @@ -295,8 +367,10 @@ class _MapAndBatchDataset(dataset_ops.MapDataset): f=self._map_func, batch_size=self._batch_size, num_parallel_batches=self._num_parallel_batches, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) # pylint: enable=protected-access @property diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py new file mode 100644 index 0000000000000000000000000000000000000000..63226fe78163c59025623a362d17c400fbe57c67 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/counter.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Counter Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import scan_ops + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + + +def Counter(start=0, step=1, dtype=dtypes.int64): + """Creates a `Dataset` of a `step`-separated count startin from `start`. + + For example: + + ```python + Dataset.count() == [0, 1, 2, ...) + Dataset.count(2) == [2, 3, ...) + Dataset.count(2, 5) == [2, 7, 12, ...) + Dataset.count(0, -1) == [0, -1, -2, ...) + Dataset.count(10, -1) == [10, 9, ...) + ``` + + Args: + start: starting value for count. + step: step size. + dtype: counter data type. + + Returns: + A `Dataset` of scalar elements. + """ + with ops.name_scope("counter"): + start = ops.convert_to_tensor(start, dtype=dtype, name="start") + step = ops.convert_to_tensor(step, dtype=dtype, name="step") + return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( + scan_ops.scan(start, lambda state, _: (state + step, state))) diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 45d6dbe7438957029b4d6b71e181cb1fc3596ecb..626a9e0edcea5928b1636c1a2a86e83657c966a5 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -21,7 +21,6 @@ from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import grouping - from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.ops import gen_dataset_ops @@ -48,6 +47,10 @@ class Dataset(dataset_ops.Dataset): def _as_variant_tensor(self): return self._dataset._as_variant_tensor() # pylint: disable=protected-access + @property + def output_classes(self): + return self._dataset.output_classes + @property def output_shapes(self): return self._dataset.output_shapes diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 238bb52b0205f9ab66f479f1b92e72ab6e38725b..aa629cba479102ee4244884e7c546615b28cf4e5 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.ops import gen_dataset_ops @@ -62,8 +63,14 @@ class IgnoreErrorsDataset(dataset_ops.Dataset): def _as_variant_tensor(self): return gen_dataset_ops.ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access - output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 6df7b22fb69bb14c41a26bd630a825442f67ee23..ef91c56726e969053fdad667dda3e89430045652 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops @@ -87,15 +88,21 @@ def group_by_window(key_func, class _VariantDataset(dataset_ops.Dataset): """A Dataset wrapper for a tf.variant-typed function argument.""" - def __init__(self, dataset_variant, output_types, output_shapes): + def __init__(self, dataset_variant, output_types, output_shapes, + output_classes): super(_VariantDataset, self).__init__() self._dataset_variant = dataset_variant self._output_types = output_types self._output_shapes = output_shapes + self._output_classes = output_classes def _as_variant_tensor(self): return self._dataset_variant + @property + def output_classes(self): + return self._output_classes + @property def output_shapes(self): return self._output_shapes @@ -137,13 +144,21 @@ class GroupByWindowDataset(dataset_ops.Dataset): def _make_key_func(self, key_func, input_dataset): """Make wrapping Defun for key_func.""" - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) def tf_key_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. - for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): arg.set_shape(shape) + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) # pylint: disable=protected-access if dataset_ops._should_unpack_args(nested_args): ret = key_func(*nested_args) @@ -165,14 +180,15 @@ class GroupByWindowDataset(dataset_ops.Dataset): def tf_reduce_func(key, window_dataset_variant): """A wrapper for Defun that facilitates shape inference.""" key.set_shape([]) - window_dataset = _VariantDataset(window_dataset_variant, - input_dataset.output_types, - input_dataset.output_shapes) + window_dataset = _VariantDataset( + window_dataset_variant, input_dataset.output_types, + input_dataset.output_shapes, input_dataset.output_classes) if not isinstance(window_dataset, dataset_ops.Dataset): raise TypeError("`window_dataset` must return a `Dataset` object.") output_dataset = reduce_func(key, window_dataset) if not isinstance(output_dataset, dataset_ops.Dataset): raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_classes = output_dataset.output_classes self._output_types = output_dataset.output_types self._output_shapes = output_dataset.output_shapes return output_dataset._as_variant_tensor() # pylint: disable=protected-access @@ -180,6 +196,10 @@ class GroupByWindowDataset(dataset_ops.Dataset): self._reduce_func = tf_reduce_func self._reduce_func.add_to_graph(ops.get_default_graph()) + @property + def output_classes(self): + return self._output_classes + @property def output_shapes(self): return self._output_shapes @@ -197,5 +217,7 @@ class GroupByWindowDataset(dataset_ops.Dataset): key_func=self._key_func, reduce_func=self._reduce_func, window_size_func=self._window_size_func, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 74a919c1fff62cfa79b0877a3d081077ca6776f0..53324e06e7f1dc249388410f0e14e42336630cd1 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops @@ -35,16 +36,22 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset - @function.Defun(*nest.flatten(input_dataset.output_types)) + @function.Defun(*nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes))) def tf_map_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the input_dataset. - for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): arg.set_shape(shape) nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - - if nest.is_sequence(nested_args): + nested_args = sparse.deserialize_sparse_tensors( + nested_args, input_dataset.output_types, input_dataset.output_shapes, + input_dataset.output_classes) + if dataset_ops._should_unpack_args(nested_args): # pylint: disable=protected-access dataset = map_func(*nested_args) else: dataset = map_func(nested_args) @@ -52,6 +59,7 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): if not isinstance(dataset, dataset_ops.Dataset): raise TypeError("`map_func` must return a `Dataset` object.") + self._output_classes = dataset.output_classes self._output_types = dataset.output_types self._output_shapes = dataset.output_shapes @@ -75,8 +83,14 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): self._block_length, self._sloppy, f=self._map_func, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_classes(self): + return self._output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7d727165feabb101549567f28a2dfa07083de244 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Datasets for random number generators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class RandomDataset(dataset_ops.Dataset): + """A `Dataset` of pseudorandom values.""" + + def __init__(self, seed=None): + """A `Dataset` of pseudorandom values.""" + super(RandomDataset, self).__init__() + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + return gen_dataset_ops.random_dataset( + seed=self._seed, + seed2=self._seed2, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.int64 diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 2e1c3153ca78e20e2628e8754b9827b817f8c732..347e5edc7b0d479dfa260e8cec500ffaaba375be 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -23,7 +23,6 @@ from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops @@ -156,8 +155,7 @@ def read_batch_features(file_pattern, features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. See `tf.parse_example`. reader: A function or class that can be called with a `filenames` tensor - and (optional) `reader_args` and returns a `Dataset` of serialized - Examples. + and (optional) `reader_args` and returns a `Dataset` of Examples. reader_args: Additional arguments to pass to the reader class. randomize_input: Whether the input should be randomized. num_epochs: Integer specifying the number of times to read through the @@ -166,7 +164,7 @@ def read_batch_features(file_pattern, shuffling but would increase memory usage and startup time. Returns: - A dict from keys in features to Tensor or SparseTensor objects. + A dict from keys in features to `Tensor` or `SparseTensor` objects. """ filenames = _get_file_names(file_pattern, randomize_input) if reader_args: @@ -174,32 +172,17 @@ def read_batch_features(file_pattern, else: dataset = reader(filenames) if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda unused_k, v: v) - elif dataset.output_types != dtypes.string: - raise TypeError("`reader` must be a dataset of `tf.string` values, " - "or `(tf.string, tf.string)` key-value pairs.") + dataset = dataset.map(lambda _, v: v) if num_epochs != 1: dataset = dataset.repeat(num_epochs) if randomize_input: dataset = dataset.shuffle(capacity) dataset = dataset.batch(batch_size) - dataset = dataset.map(lambda x: _parse_example(x, features)) + dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) + dataset = dataset.prefetch(1) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() - index = 0 - result = {} - for key in sorted(features.keys()): - feature = features[key] - if isinstance(feature, parsing_ops.FixedLenFeature): - result[key] = outputs[index] - index += 1 - else: - result[key] = sparse_tensor_lib.SparseTensor( - indices=outputs[index], - values=outputs[index + 1], - dense_shape=outputs[index + 2]) - index += 3 - return result + return outputs def _get_file_names(file_pattern, randomize_input): @@ -233,18 +216,6 @@ def _get_file_names(file_pattern, randomize_input): return file_names -def _parse_example(serialized, features): - parsed = parsing_ops.parse_example(serialized, features) - result = [] - for key in sorted(features.keys()): - val = parsed[key] - if isinstance(val, sparse_tensor_lib.SparseTensor): - result.extend([val.indices, val.values, val.dense_shape]) - else: - result.append(val) - return tuple(result) - - class SqlDataset(contrib_dataset_ops.Dataset): def __init__(self, driver_name, data_source_name, query, output_types): @@ -299,6 +270,10 @@ class _SqlDataset(dataset_ops.Dataset): nest.flatten(self.output_types), nest.flatten(self.output_shapes)) + @property + def output_classes(self): + return nest.map_structure(lambda _: ops.Tensor, self._output_types) + @property def output_shapes(self): return nest.map_structure(lambda _: tensor_shape.TensorShape([]), diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 5acaed48a3d73e93706bdd0b5b2d614b0c565ab7..2744786e9eec4c9268ba854df6ea761339bb0b4e 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -21,6 +21,7 @@ import collections from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -43,6 +44,7 @@ class _ScanDataset(dataset_ops.Dataset): # Compute initial values for the state shapes and types based on # the initial state. These will be refined by running # `tf_scan_func` one or more times below. + # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor. self._state_shapes = nest.pack_sequence_as( self._initial_state, [t.shape for t in nest.flatten(self._initial_state)]) @@ -51,6 +53,7 @@ class _ScanDataset(dataset_ops.Dataset): [t.dtype for t in nest.flatten(self._initial_state)]) # Will be populated by calling `tf_scan_func`. + self._output_classes = None self._output_shapes = None self._output_types = None @@ -65,14 +68,17 @@ class _ScanDataset(dataset_ops.Dataset): # Create a list in which `tf_scan_func` will store the s flat_new_state_shapes = [] - @function.Defun( - *(flat_state_types + nest.flatten(input_dataset.output_types))) + @function.Defun(*(flat_state_types + nest.flatten( + sparse.as_dense_types(input_dataset.output_types, + input_dataset.output_classes)))) def tf_scan_func(*args): """A wrapper for Defun that facilitates shape inference.""" # Pass in shape information from the state and input_dataset. - for arg, shape in zip( - args, - flat_state_shapes + nest.flatten(input_dataset.output_shapes)): + # TODO(b/69424092): Check that neither inputs nor outputs are sparse. + dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, + input_dataset.output_classes) + for arg, shape in zip(args, + flat_state_shapes + nest.flatten(dense_shapes)): arg.set_shape(shape) pivot = len(flat_state_shapes) @@ -106,6 +112,8 @@ class _ScanDataset(dataset_ops.Dataset): "state. Expected %s; got %s." % (self._state_types, nest.pack_sequence_as( self._state_types, [t.dtype for t in flat_new_state]))) + self._output_classes = nest.pack_sequence_as( + output_value, [ops.Tensor for _ in flat_output_value]) self._output_types = nest.pack_sequence_as( output_value, [t.dtype for t in flat_output_value]) @@ -144,8 +152,14 @@ class _ScanDataset(dataset_ops.Dataset): nest.flatten(self._initial_state), self._scan_func.captured_inputs, f=self._scan_func, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_classes(self): + return self._output_classes @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..460732d65e4e652058ad821fbed45d365b4f41c1 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental shuffle ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import random_ops +from tensorflow.python.data.ops import dataset_ops + + +def shuffle_and_repeat(buffer_size, count=None, seed=None): + """Shuffles and repeats a Dataset returning a new permutation for each epoch. + + `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))` + + is equivalent to + + `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)` + + The difference is that the latter dataset is not serializable. So, + if you need to checkpoint an input pipeline with reshuffling you must use + this implementation. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + maximum number elements that will be buffered when prefetching. + count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + number of times the dataset should be repeated. The default behavior + (if `count` is `None` or `-1`) is for the dataset be repeated + indefinitely. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + def _apply_fn(dataset): # pylint: disable=missing-docstring + random_ds = random_ops.RandomDataset(seed).apply( + batching.batch_and_drop_remainder(2)) + if count is not None and count is not -1: + random_ds = random_ds.take(count) + + def map_fn(seeds): + return dataset_ops.ShuffleDataset( + input_dataset=dataset, + buffer_size=buffer_size, + seed=seeds[0], + reshuffle_each_iteration=False, + seed2=seeds[1]) + + return random_ds.flat_map(map_fn) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b8875bd533ddc9e2c195646619dccf3aab5225e4 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -0,0 +1,177 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for gathering statistics from `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +class StatsAggregator(object): + """A stateful resource that aggregates statistics from one or more iterators. + + To record statistics, use one of the custom transformation functions defined + in this module when defining your @{tf.data.Dataset}. All statistics will be + aggregated by the `StatsAggregator` that is associated with a particular + iterator (see below). For example, to record the total number of bytes + produced by iterating over a dataset: + + ```python + dataset = ... + dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes")) + ``` + + To associate a `StatsAggregator` with a @{tf.data.Iterator} object, use + the following pattern: + + ```python + dataset = ... + iterator = dataset.make_one_shot_iterator() + stats_aggregator = stats_ops.StatsAggregator() + set_op = stats_op.set_stats_aggregator_op(iterator, stats_aggregator) + + with tf.Session() as sess: + # Running `set_op` will associate `iterator` with `stats_aggregator`. + sess.run(set_op) + ``` + + To get a protocol buffer summary of the currently aggregated statistics, + use the `StatsAggregator.get_summary()` tensor. The easiest way to do this + is to add the returned tensor to the @{tf.GraphKeys.SUMMARIES} collection, + so that the summaries will be included with any existing summaries. + + ```python + stats_aggregator = stats_ops.StatsAggregator() + stats_summary = stats_aggregator.get_summary() + tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary) + ``` + + Note: This interface is experimental and expected to change. In particular, + we expect to add other implementations of `StatsAggregator` that provide + different ways of exporting statistics, and add more types of statistics. + """ + + def __init__(self): + """Creates a `StatsAggregator`.""" + self._resource = gen_dataset_ops.stats_aggregator_handle() + + def get_summary(self): + """Returns a string @{tf.Tensor} that summarizes the aggregated statistics. + + The returned tensor will contain a serialized @{tf.summary.Summary} protocol + buffer, which can be used with the standard TensorBoard logging facilities. + + Returns: + A scalar string @{tf.Tensor} that summarizes the aggregated statistics. + """ + return gen_dataset_ops.stats_aggregator_summary(self._resource) + + def subscribe(self, iterator): + """Returns a @{tf.Operation} to associate this aggregator with `iterator`. + + Note: Each @{tf.data.Iterator} can be associated with at most one + `StatsAggregator`. After running the operation that this function + returns, all statistics recorded in the iteration of `iterator` + will be stored in `stats_aggregator`. + + Args: + iterator: A @{tf.data.Iterator} object. + + Returns: + A @{tf.Operation} that, when run, associates this aggregator with + `iterator`. + """ + if not isinstance(iterator, iterator_ops.Iterator): + raise TypeError("`iterator` must be a `tf.data.Iterator` object.") + return gen_dataset_ops.iterator_set_stats_aggregator( + iterator._iterator_resource, self._resource) # pylint: disable=protected-access + + +def bytes_produced_stats(tag): + """Records the number of bytes produced by each element of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with an iterator + over the output dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will + be associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset, + tag) + + return _apply_fn + + +def latency_stats(tag): + """Records the latency of producing each element of the input dataset. + + To consume the statistics, associate a `StatsAggregator` with an iterator + over the output dataset. + + Args: + tag: String. All statistics recorded by the returned transformation will + be associated with the given `tag`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag) + + return _apply_fn + + +class _StatsDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and also records statistics.""" + + def __init__(self, input_dataset, op_function, tag): + super(_StatsDataset, self).__init__() + self._input_dataset = input_dataset + self._op_function = op_function + self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string) + + def _as_variant_tensor(self): + return self._op_function( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._tag, + output_shapes=nest.flatten(self.output_shapes), + output_types=nest.flatten(self.output_types)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig index d3d201afd5761e7c5c136301c779222bedc68492..cafb9314caee1c4907786b8101e7c71bd7095306 100644 --- a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig @@ -2,7 +2,7 @@ %include "net/proto/swig/protofunc.swig" -#ifndef MUST_USE_RESULT +#ifndef ABSL_MUST_USE_RESULT #error Use this file only as a %include or %import after google.swig. #endif diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 4a4f3789016bed5db475da81b2448b682f158353..b2c641f8ab3ea23c5135042e4b1223d487ae8cbc 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -2,12 +2,15 @@ # Contains ops for statistical distributions (with pdf, cdf, sample, etc...). # APIs here are meant to evolve over time. +package(default_visibility = [ + "//learning/brain/contrib/bayesflow:__subpackages__", + "//tensorflow:__subpackages__", +]) + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( @@ -137,6 +140,23 @@ cuda_py_test( ], ) +cuda_py_test( + name = "cauchy_test", + size = "medium", + srcs = ["python/kernel_tests/cauchy_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "chi2_test", srcs = ["python/kernel_tests/chi2_test.py"], @@ -184,6 +204,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "half_normal_test", + size = "medium", + srcs = ["python/kernel_tests/half_normal_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "inverse_gamma_test", srcs = ["python/kernel_tests/inverse_gamma_test.py"], diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 16f6533e57347a5fe41b017c9855d216fba9da82..66827179e9fa1bea852f55246c263c4696cf3bdc 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -24,6 +24,7 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops.binomial import * +from tensorflow.contrib.distributions.python.ops.cauchy import * from tensorflow.contrib.distributions.python.ops.chi2 import * from tensorflow.contrib.distributions.python.ops.conditional_distribution import * from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * @@ -35,6 +36,7 @@ from tensorflow.contrib.distributions.python.ops.distribution_util import softpl from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag from tensorflow.contrib.distributions.python.ops.estimator import * from tensorflow.contrib.distributions.python.ops.geometric import * +from tensorflow.contrib.distributions.python.ops.half_normal import * from tensorflow.contrib.distributions.python.ops.independent import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * from tensorflow.contrib.distributions.python.ops.logistic import * @@ -83,6 +85,7 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'bijectors', + 'Cauchy', 'ConditionalDistribution', 'ConditionalTransformedDistribution', 'FULLY_REPARAMETERIZED', @@ -105,6 +108,7 @@ _allowed_symbols = [ 'Gamma', 'GammaWithSoftplusConcentrationRate', 'Geometric', + 'HalfNormal', 'Independent', 'InverseGamma', 'InverseGammaWithSoftplusConcentrationRate', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index 25a9b6f5fe2ed6d218d6b44650fce17fa89c0664..288d9d8dd6f17cd6348d3d72aea4408e26913ebd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -22,9 +22,9 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import _gen_mask from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import masked_autoregressive_default_template from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import _gen_mask from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 38b3a23c2d684a6f89b7c4be4a763c649bf4de15..49451446b56d290f130c5db90c13b94974d92dc9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -28,8 +28,19 @@ from tensorflow.python.ops.distributions.bijector_test_util import assert_biject from tensorflow.python.platform import test -class ReshapeBijectorTest(test.TestCase): - """Tests correctness of the reshape transformation.""" +class _ReshapeBijectorTest(object): + """Base class for testing the reshape transformation. + + Methods defined in this class call a method self.build_shapes() that + is implemented by subclasses defined below, returning respectively + ReshapeBijectorTestStatic: static shapes, + ReshapeBijectorTestDynamic: shape placeholders of known ndims, and + ReshapeBijectorTestDynamicNdims: shape placeholders of unspecified ndims, + so that each test in this base class is automatically run over all + three cases. The subclasses also implement assertRaisesError to test + for either Python exceptions (in the case of static shapes) or + TensorFlow op errors (dynamic shapes). + """ def setUp(self): self._rng = np.random.RandomState(42) @@ -40,9 +51,10 @@ class ReshapeBijectorTest(test.TestCase): expected_y = np.reshape(expected_x, [4, 6]) with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,]) bijector = Reshape( - event_shape_out=[6,], - event_shape_in=[3, 2], + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) (x_, y_, @@ -52,66 +64,23 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.forward_log_det_jacobian(expected_x), bijector.inverse_log_det_jacobian(expected_y), - )) + ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) - def testEventShapeDynamicNdims(self): - """Check forward/inverse shape methods with dynamic ndims.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) - - bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, validate_args=True) - - # using the _tensor methods, we should always get a fully-specified - # result since these are evaluated at graph runtime. - with self.test_session() as sess: - (shape_out_, - shape_in_) = sess.run(( - bijector.forward_event_shape_tensor(shape_in), - bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeDynamic(self): - """Check shape methods with static ndims but dynamic shape.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_partial = tensor_shape.TensorShape([None,]) - shape_in_ph = array_ops.placeholder( - shape=[1,], dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_partial = tensor_shape.TensorShape([None, None]) - shape_out_ph = array_ops.placeholder( - shape=[2,], dtype=dtypes.int32) + def testEventShapeTensor(self): + """Test event_shape_tensor methods when even ndims may be dynamic.""" + shape_in_static = [2, 3] + shape_out_static = [6,] + shape_in, shape_out, feed_dict = self.build_shapes(shape_in_static, + shape_out_static) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, - validate_args=True) - - # if event shapes are not statically available, should - # return partially-specified TensorShapes. - self.assertAllEqual( - bijector.forward_event_shape(shape_in).as_list(), - shape_out_partial.as_list()) - self.assertAllEqual( - bijector.inverse_event_shape(shape_out).as_list(), - shape_in_partial.as_list()) + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) # using the _tensor methods, we should always get a fully-specified # result since these are evaluated at graph runtime. @@ -120,42 +89,9 @@ class ReshapeBijectorTest(test.TestCase): shape_in_) = sess.run(( bijector.forward_event_shape_tensor(shape_in), bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeStatic(self): - """Check shape methods when shape is statically known.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_out = tensor_shape.TensorShape([2, 3]) - - bijector_static = Reshape( - event_shape_out=shape_out, - event_shape_in=shape_in, - validate_args=True) - - # test that forward_ and inverse_event_shape do sensible things - # when shapes are statically known. - self.assertEqual( - bijector_static.forward_event_shape(shape_in), - shape_out) - self.assertEqual( - bijector_static.inverse_event_shape(shape_out), - shape_in) - - with self.test_session() as sess: - (shape_out_static_, - shape_in_static_, - ) = sess.run(( - bijector_static.forward_event_shape_tensor(shape_in), - bijector_static.inverse_event_shape_tensor(shape_out), - )) - self.assertAllEqual(shape_out, shape_out_static_) - self.assertAllEqual(shape_in, shape_in_static_) + ), feed_dict=feed_dict) + self.assertAllEqual(shape_out_static, shape_out_) + self.assertAllEqual(shape_in_static, shape_in_) def testScalarReshape(self): """Test reshaping to and from a scalar shape ().""" @@ -166,11 +102,11 @@ class ReshapeBijectorTest(test.TestCase): expected_x_scalar = np.random.randn(1,) expected_y_scalar = expected_x_scalar[0] + shape_in, shape_out, feed_dict = self.build_shapes([], [1,]) with self.test_session() as sess: bijector = Reshape( - event_shape_out=[], - event_shape_in=[1,], validate_args=True) - + event_shape_out=shape_in, + event_shape_in=shape_out, validate_args=True) (x_, y_, x_scalar_, @@ -180,53 +116,178 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.inverse(expected_y_scalar), bijector.forward(expected_x_scalar), - )) + ), feed_dict=feed_dict) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) - def testRaisesOpError(self): - x1 = np.random.randn(4, 2, 3) - x2 = np.random.randn(4, 3, 2) - x3 = np.random.randn(4, 5, 1, 1) + def testMultipleUnspecifiedDimensionsOpError(self): with self.test_session() as sess: - shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) - shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,]) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) - with self.assertRaisesOpError( + with self.assertRaisesError( + "elements must have at most one `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testInvalidDimensionsOpError(self): + + with self.test_session() as sess: + + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "elements must be either positive integers or `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + def testValidButNonMatchingInputOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # Here we pass in a tensor (x) whose shape is compatible with + # the output shape, so tf.reshape will throw no error, but + # doesn't match the expected input shape. + with self.assertRaisesError( "Input `event_shape` does not match `event_shape_in`."): - sess.run(bijector.forward(x2), - feed_dict={shape_out_ph: [1, 6, 1], - shape_in_ph: [2, 3]}) + sess.run(bijector.forward(x), + feed_dict=feed_dict) - with self.assertRaisesOpError( - "event_shape_out entries must be positive."): - sess.run(bijector.forward(x1), - feed_dict={shape_out_ph: [-1, -1, 6], - shape_in_ph: [2, 3]}) + def testValidButNonMatchingInputPartiallySpecifiedOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x), + feed_dict=feed_dict) + + def testInputOutputMismatchOpError(self): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 1, 1, 5) + + with self.test_session() as sess: + shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3], + [1, 1, 5]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) # test that *all* methods check basic assertions - fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): + with self.assertRaisesError( + "Input to reshape is a tensor with"): sess.run(bijector.forward(x1), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse_log_det_jacobian(x3), - feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.forward_log_det_jacobian(x1), - feed_dict=fd_mismatched) + with self.assertRaisesError( + "Input to reshape is a tensor with"): + sess.run(bijector.inverse(x2), feed_dict=fd_mismatched) + + def testOneShapePartiallySpecified(self): + expected_x = np.random.randn(4, 6) + expected_y = np.reshape(expected_x, [4, 2, 3]) + + with self.test_session() as sess: + # one of input/output shapes is partially specified + shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testBothShapesPartiallySpecified(self): + expected_x = np.random.randn(4, 2, 3) + expected_y = np.reshape(expected_x, [4, 3, 2]) + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testDefaultVectorShape(self): + expected_x = np.random.randn(4, 4) + expected_y = np.reshape(expected_x, [4, 2, 2]) + with self.test_session() as sess: + _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2]) + bijector = Reshape(shape_out, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def build_shapes(self, *args, **kwargs): + raise NotImplementedError("Subclass failed to implement `build_shapes`.") + + +class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_static = shape_in + shape_out_static = shape_out + feed_dict = {} + return shape_in_static, shape_out_static, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesRegexp(Exception, msg) + + def testEventShape(self): + shape_in_static = tensor_shape.TensorShape([2, 3]) + shape_out_static = tensor_shape.TensorShape([6,]) + bijector = Reshape( + event_shape_out=shape_out_static, + event_shape_in=shape_in_static, validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector.forward_event_shape(shape_in_static), + shape_out_static) + self.assertEqual( + bijector.inverse_event_shape(shape_out_static), + shape_in_static) def testBijectiveAndFinite(self): x = np.random.randn(4, 2, 3) @@ -238,5 +299,32 @@ class ReshapeBijectorTest(test.TestCase): validate_args=True) assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=(len(shape_in),), + dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=(len(shape_out),), + dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + +class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..73747db31c86b67eaad5aeab7d5e80191e12b333 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py @@ -0,0 +1,438 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Cauchy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import numpy as np + +from tensorflow.contrib.distributions.python.ops import cauchy as cauchy_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + + +stats = try_import("scipy.stats") + + +class CauchyTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(123) + + def assertAllFinite(self, tensor): + is_finite = np.isfinite(tensor.eval()) + all_true = np.ones_like(is_finite, dtype=np.bool) + self.assertAllEqual(all_true, is_finite) + + def _testParamShapes(self, sample_shape, expected): + with self.test_session(): + param_shapes = cauchy_lib.Cauchy.param_shapes(sample_shape) + loc_shape, scale_shape = param_shapes["loc"], param_shapes["scale"] + self.assertAllEqual(expected, loc_shape.eval()) + self.assertAllEqual(expected, scale_shape.eval()) + loc = array_ops.zeros(loc_shape) + scale = array_ops.ones(scale_shape) + self.assertAllEqual(expected, + array_ops.shape( + cauchy_lib.Cauchy(loc, scale).sample()).eval()) + + def _testParamStaticShapes(self, sample_shape, expected): + param_shapes = cauchy_lib.Cauchy.param_static_shapes(sample_shape) + loc_shape, scale_shape = param_shapes["loc"], param_shapes["scale"] + self.assertEqual(expected, loc_shape) + self.assertEqual(expected, scale_shape) + + def testParamShapes(self): + sample_shape = [10, 3, 4] + self._testParamShapes(sample_shape, sample_shape) + self._testParamShapes(constant_op.constant(sample_shape), sample_shape) + + def testParamStaticShapes(self): + sample_shape = [10, 3, 4] + self._testParamStaticShapes(sample_shape, sample_shape) + self._testParamStaticShapes( + tensor_shape.TensorShape(sample_shape), sample_shape) + + def testCauchyLogPDF(self): + with self.test_session(): + batch_size = 6 + loc = constant_op.constant([3.0] * batch_size) + scale = constant_op.constant([np.sqrt(10.0)] * batch_size) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + log_pdf = cauchy.log_prob(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + log_pdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape, log_pdf.eval().shape) + + pdf = cauchy.prob(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), pdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, pdf.shape) + self.assertAllEqual(cauchy.batch_shape, pdf.eval().shape) + + if not stats: + return + expected_log_pdf = stats.cauchy(loc.eval(), scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testCauchyLogPDFMultidimensional(self): + with self.test_session(): + batch_size = 6 + loc = constant_op.constant([[3.0, -3.0]] * batch_size) + scale = constant_op.constant( + [[np.sqrt(10.0), np.sqrt(15.0)]] * batch_size) + x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + log_pdf = cauchy.log_prob(x) + log_pdf_values = log_pdf.eval() + self.assertEqual(log_pdf.shape, (6, 2)) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + log_pdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape, log_pdf.eval().shape) + + pdf = cauchy.prob(x) + pdf_values = pdf.eval() + self.assertEqual(pdf.shape, (6, 2)) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), pdf_values.shape) + self.assertAllEqual(cauchy.batch_shape, pdf.shape) + self.assertAllEqual(cauchy.batch_shape, pdf_values.shape) + + if not stats: + return + expected_log_pdf = stats.cauchy(loc.eval(), scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + + def testCauchyCDF(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + cdf = cauchy.cdf(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), cdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), cdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, cdf.shape) + self.assertAllEqual(cauchy.batch_shape, cdf.eval().shape) + if not stats: + return + expected_cdf = stats.cauchy(loc, scale).cdf(x) + self.assertAllClose(expected_cdf, cdf.eval(), atol=0) + + def testCauchySurvivalFunction(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + sf = cauchy.survival_function(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), sf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), sf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, sf.shape) + self.assertAllEqual(cauchy.batch_shape, sf.eval().shape) + if not stats: + return + expected_sf = stats.cauchy(loc, scale).sf(x) + self.assertAllClose(expected_sf, sf.eval(), atol=0) + + def testCauchyLogCDF(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + cdf = cauchy.log_cdf(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), cdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), cdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, cdf.shape) + self.assertAllEqual(cauchy.batch_shape, cdf.eval().shape) + + if not stats: + return + expected_cdf = stats.cauchy(loc, scale).logcdf(x) + self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5) + + def testFiniteGradientAtDifficultPoints(self): + for dtype in [np.float32, np.float64]: + g = ops.Graph() + with g.as_default(): + loc = variables.Variable(dtype(0.0)) + scale = variables.Variable(dtype(1.0)) + dist = cauchy_lib.Cauchy(loc=loc, scale=scale) + x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype) + for func in [ + dist.cdf, dist.log_cdf, dist.survival_function, + dist.log_survival_function, dist.log_prob, dist.prob + ]: + value = func(x) + grads = gradients_impl.gradients(value, [loc, scale]) + with self.test_session(graph=g): + variables.global_variables_initializer().run() + self.assertAllFinite(value) + self.assertAllFinite(grads[0]) + self.assertAllFinite(grads[1]) + + def testCauchyLogSurvivalFunction(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + sf = cauchy.log_survival_function(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), sf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), sf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, sf.shape) + self.assertAllEqual(cauchy.batch_shape, sf.eval().shape) + + if not stats: + return + expected_sf = stats.cauchy(loc, scale).logsf(x) + self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5) + + def testCauchyEntropy(self): + with self.test_session(): + loc = np.array([1.0, 1.0, 1.0]) + scale = np.array([[1.0, 2.0, 3.0]]) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + entropy = cauchy.entropy() + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), entropy.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + entropy.eval().shape) + self.assertAllEqual(cauchy.batch_shape, entropy.shape) + self.assertAllEqual(cauchy.batch_shape, entropy.eval().shape) + + if not stats: + return + expected_entropy = stats.cauchy(loc, scale[0]).entropy().reshape((1, 3)) + self.assertAllClose(expected_entropy, entropy.eval()) + + def testCauchyMode(self): + with self.test_session(): + # Mu will be broadcast to [7, 7, 7]. + loc = [7.] + scale = [11., 12., 13.] + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertAllEqual((3,), cauchy.mode().shape) + self.assertAllEqual([7., 7, 7], cauchy.mode().eval()) + + def testCauchyMean(self): + with self.test_session(): + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertAllEqual((3,), cauchy.mean().shape) + self.assertAllEqual([np.nan] * 3, cauchy.mean().eval()) + + def testCauchyNanMean(self): + with self.test_session(): + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale, allow_nan_stats=False) + + with self.assertRaises(ValueError): + cauchy.mean().eval() + + def testCauchyQuantile(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + p = np.linspace(0.000001, 0.999999, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + x = cauchy.quantile(p) + + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), x.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), x.eval().shape) + self.assertAllEqual(cauchy.batch_shape, x.shape) + self.assertAllEqual(cauchy.batch_shape, x.eval().shape) + + if not stats: + return + expected_x = stats.cauchy(loc, scale).ppf(p) + self.assertAllClose(expected_x, x.eval(), atol=0.) + + def testCauchyVariance(self): + with self.test_session(): + # scale will be broadcast to [7, 7, 7] + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertAllEqual((3,), cauchy.variance().shape) + self.assertAllEqual([np.nan] * 3, cauchy.variance().eval()) + + def testCauchyNanVariance(self): + with self.test_session(): + # scale will be broadcast to [7, 7, 7] + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale, allow_nan_stats=False) + + with self.assertRaises(ValueError): + cauchy.variance().eval() + + def testCauchyStandardDeviation(self): + with self.test_session(): + # scale will be broadcast to [7, 7, 7] + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertAllEqual((3,), cauchy.stddev().shape) + self.assertAllEqual([np.nan] * 3, cauchy.stddev().eval()) + + def testCauchyNanStandardDeviation(self): + with self.test_session(): + # scale will be broadcast to [7, 7, 7] + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale, allow_nan_stats=False) + + with self.assertRaises(ValueError): + cauchy.stddev().eval() + + def testCauchySample(self): + with self.test_session(): + loc = constant_op.constant(3.0) + scale = constant_op.constant(1.0) + loc_v = 3.0 + n = constant_op.constant(100000) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + samples = cauchy.sample(n) + sample_values = samples.eval() + + self.assertEqual(sample_values.shape, (100000,)) + self.assertAllClose(np.median(sample_values), loc_v, atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval())) + + self.assertAllEqual(expected_shape, samples.shape) + self.assertAllEqual(expected_shape, sample_values.shape) + + expected_shape = ( + tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape)) + + self.assertAllEqual(expected_shape, samples.shape) + self.assertAllEqual(expected_shape, sample_values.shape) + + def testCauchySampleMultiDimensional(self): + with self.test_session(): + batch_size = 2 + loc = constant_op.constant([[3.0, -3.0]] * batch_size) + scale = constant_op.constant([[0.5, 1.0]] * batch_size) + loc_v = [3.0, -3.0] + n = constant_op.constant(100000) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + samples = cauchy.sample(n) + sample_values = samples.eval() + self.assertEqual(samples.shape, (100000, batch_size, 2)) + self.assertAllClose( + np.median(sample_values[:, 0, 0]), loc_v[0], atol=1e-1) + self.assertAllClose( + np.median(sample_values[:, 0, 1]), loc_v[1], atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, samples.shape) + self.assertAllEqual(expected_shape, sample_values.shape) + + expected_shape = ( + tensor_shape.TensorShape([n.eval()]).concatenate(cauchy.batch_shape)) + self.assertAllEqual(expected_shape, samples.shape) + self.assertAllEqual(expected_shape, sample_values.shape) + + def testCauchyNegativeLocFails(self): + with self.test_session(): + cauchy = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True) + with self.assertRaisesOpError("Condition x > 0 did not hold"): + cauchy.mode().eval() + + def testCauchyShape(self): + with self.test_session(): + loc = constant_op.constant([-3.0] * 5) + scale = constant_op.constant(11.0) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertEqual(cauchy.batch_shape_tensor().eval(), [5]) + self.assertEqual(cauchy.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(cauchy.event_shape_tensor().eval(), []) + self.assertEqual(cauchy.event_shape, tensor_shape.TensorShape([])) + + def testCauchyShapeWithPlaceholders(self): + loc = array_ops.placeholder(dtype=dtypes.float32) + scale = array_ops.placeholder(dtype=dtypes.float32) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + with self.test_session() as sess: + # get_batch_shape should return an "" tensor. + self.assertEqual(cauchy.batch_shape, tensor_shape.TensorShape(None)) + self.assertEqual(cauchy.event_shape, ()) + self.assertAllEqual(cauchy.event_shape_tensor().eval(), []) + self.assertAllEqual( + sess.run( + cauchy.batch_shape_tensor(), + feed_dict={ + loc: 5.0, + scale: [1.0, 2.0] + }), [2]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e75660083dc2edd1759a3a54e221d9e8a268c3 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py @@ -0,0 +1,320 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for initializers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import numpy as np + +from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +stats = try_import("scipy.stats") + + +class HalfNormalTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(123) + + def assertAllFinite(self, tensor): + is_finite = np.isfinite(tensor.eval()) + all_true = np.ones_like(is_finite, dtype=np.bool) + self.assertAllEqual(all_true, is_finite) + + def _testParamShapes(self, sample_shape, expected): + with self.test_session(): + param_shapes = hn_lib.HalfNormal.param_shapes(sample_shape) + scale_shape = param_shapes["scale"] + self.assertAllEqual(expected, scale_shape.eval()) + scale = array_ops.ones(scale_shape) + self.assertAllEqual( + expected, + array_ops.shape(hn_lib.HalfNormal(scale).sample()).eval()) + + def _testParamStaticShapes(self, sample_shape, expected): + param_shapes = hn_lib.HalfNormal.param_static_shapes(sample_shape) + scale_shape = param_shapes["scale"] + self.assertEqual(expected, scale_shape) + + def _testBatchShapes(self, dist, tensor): + self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.shape) + self.assertAllEqual(dist.batch_shape_tensor().eval(), tensor.eval().shape) + self.assertAllEqual(dist.batch_shape, tensor.shape) + self.assertAllEqual(dist.batch_shape, tensor.eval().shape) + + def testParamShapes(self): + sample_shape = [10, 3, 4] + self._testParamShapes(sample_shape, sample_shape) + self._testParamShapes(constant_op.constant(sample_shape), sample_shape) + + def testParamStaticShapes(self): + sample_shape = [10, 3, 4] + self._testParamStaticShapes(sample_shape, sample_shape) + self._testParamStaticShapes( + tensor_shape.TensorShape(sample_shape), sample_shape) + + def testHalfNormalLogPDF(self): + with self.test_session(): + batch_size = 6 + scale = constant_op.constant([3.0] * batch_size) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) + halfnorm = hn_lib.HalfNormal(scale=scale) + + log_pdf = halfnorm.log_prob(x) + self._testBatchShapes(halfnorm, log_pdf) + + pdf = halfnorm.prob(x) + self._testBatchShapes(halfnorm, pdf) + + if not stats: + return + expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testHalfNormalLogPDFMultidimensional(self): + with self.test_session(): + batch_size = 6 + scale = constant_op.constant([[3.0, 1.0]] * batch_size) + x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T + halfnorm = hn_lib.HalfNormal(scale=scale) + + log_pdf = halfnorm.log_prob(x) + self._testBatchShapes(halfnorm, log_pdf) + + pdf = halfnorm.prob(x) + self._testBatchShapes(halfnorm, pdf) + + if not stats: + return + expected_log_pdf = stats.halfnorm(scale=scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testHalfNormalCDF(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + halfnorm = hn_lib.HalfNormal(scale=scale) + + cdf = halfnorm.cdf(x) + self._testBatchShapes(halfnorm, cdf) + + log_cdf = halfnorm.log_cdf(x) + self._testBatchShapes(halfnorm, log_cdf) + + if not stats: + return + expected_logcdf = stats.halfnorm(scale=scale).logcdf(x) + self.assertAllClose(expected_logcdf, log_cdf.eval(), atol=0) + self.assertAllClose(np.exp(expected_logcdf), cdf.eval(), atol=0) + + def testHalfNormalSurvivalFunction(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sf = halfnorm.survival_function(x) + self._testBatchShapes(halfnorm, sf) + + log_sf = halfnorm.log_survival_function(x) + self._testBatchShapes(halfnorm, log_sf) + + if not stats: + return + expected_logsf = stats.halfnorm(scale=scale).logsf(x) + self.assertAllClose(expected_logsf, log_sf.eval(), atol=0) + self.assertAllClose(np.exp(expected_logsf), sf.eval(), atol=0) + + def testHalfNormalQuantile(self): + with self.test_session(): + batch_size = 50 + scale = self._rng.rand(batch_size) + 1.0 + p = np.linspace(0., 1.0, batch_size).astype(np.float64) + + halfnorm = hn_lib.HalfNormal(scale=scale) + x = halfnorm.quantile(p) + self._testBatchShapes(halfnorm, x) + + if not stats: + return + expected_x = stats.halfnorm(scale=scale).ppf(p) + self.assertAllClose(expected_x, x.eval(), atol=0) + + def testFiniteGradients(self): + for dtype in [np.float32, np.float64]: + g = ops.Graph() + with g.as_default(): + scale = variables.Variable(dtype(3.0)) + dist = hn_lib.HalfNormal(scale=scale) + x = np.array([0.01, 0.1, 1., 5., 10.]).astype(dtype) + for func in [ + dist.cdf, dist.log_cdf, dist.survival_function, + dist.log_prob, dist.prob, dist.log_survival_function, + ]: + print(func.__name__) + value = func(x) + grads = gradients_impl.gradients(value, [scale]) + with self.test_session(graph=g): + variables.global_variables_initializer().run() + self.assertAllFinite(value) + self.assertAllFinite(grads[0]) + + def testHalfNormalEntropy(self): + with self.test_session(): + scale = np.array([[1.0, 2.0, 3.0]]) + halfnorm = hn_lib.HalfNormal(scale=scale) + + # See https://en.wikipedia.org/wiki/Half-normal_distribution for the + # entropy formula used here. + expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5 + + entropy = halfnorm.entropy() + self._testBatchShapes(halfnorm, entropy) + self.assertAllClose(expected_entropy, entropy.eval()) + + def testHalfNormalMeanAndMode(self): + with self.test_session(): + scale = np.array([11., 12., 13.]) + + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_mean = scale * np.sqrt(2.0) / np.sqrt(np.pi) + + self.assertAllEqual((3,), halfnorm.mean().eval().shape) + self.assertAllEqual(expected_mean, halfnorm.mean().eval()) + + self.assertAllEqual((3,), halfnorm.mode().eval().shape) + self.assertAllEqual([0., 0., 0.], halfnorm.mode().eval()) + + def testHalfNormalVariance(self): + with self.test_session(): + scale = np.array([7., 7., 7.]) + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi) + + self.assertAllEqual((3,), halfnorm.variance().eval().shape) + self.assertAllEqual(expected_variance, halfnorm.variance().eval()) + + def testHalfNormalStandardDeviation(self): + with self.test_session(): + scale = np.array([7., 7., 7.]) + halfnorm = hn_lib.HalfNormal(scale=scale) + expected_variance = scale ** 2.0 * (1.0 - 2.0 / np.pi) + + self.assertAllEqual((3,), halfnorm.stddev().shape) + self.assertAllEqual(np.sqrt(expected_variance), halfnorm.stddev().eval()) + + def testHalfNormalSample(self): + with self.test_session(): + scale = constant_op.constant(3.0) + n = constant_op.constant(100000) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sample = halfnorm.sample(n) + + self.assertEqual(sample.eval().shape, (100000,)) + self.assertAllClose(sample.eval().mean(), + 3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, sample.shape) + self.assertAllEqual(expected_shape, sample.eval().shape) + + expected_shape_static = (tensor_shape.TensorShape( + [n.eval()]).concatenate(halfnorm.batch_shape)) + self.assertAllEqual(expected_shape_static, sample.shape) + self.assertAllEqual(expected_shape_static, sample.eval().shape) + + def testHalfNormalSampleMultiDimensional(self): + with self.test_session(): + batch_size = 2 + scale = constant_op.constant([[2.0, 3.0]] * batch_size) + n = constant_op.constant(100000) + halfnorm = hn_lib.HalfNormal(scale=scale) + + sample = halfnorm.sample(n) + self.assertEqual(sample.shape, (100000, batch_size, 2)) + self.assertAllClose(sample.eval()[:, 0, 0].mean(), + 2.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + self.assertAllClose(sample.eval()[:, 0, 1].mean(), + 3.0 * np.sqrt(2.0) / np.sqrt(np.pi), atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(halfnorm.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, sample.shape) + self.assertAllEqual(expected_shape, sample.eval().shape) + + expected_shape_static = (tensor_shape.TensorShape( + [n.eval()]).concatenate(halfnorm.batch_shape)) + self.assertAllEqual(expected_shape_static, sample.shape) + self.assertAllEqual(expected_shape_static, sample.eval().shape) + + def testNegativeSigmaFails(self): + with self.test_session(): + halfnorm = hn_lib.HalfNormal(scale=[-5.], validate_args=True, name="G") + with self.assertRaisesOpError("Condition x > 0 did not hold"): + halfnorm.mean().eval() + + def testHalfNormalShape(self): + with self.test_session(): + scale = constant_op.constant([6.0] * 5) + halfnorm = hn_lib.HalfNormal(scale=scale) + + self.assertEqual(halfnorm.batch_shape_tensor().eval(), [5]) + self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(halfnorm.event_shape_tensor().eval(), []) + self.assertEqual(halfnorm.event_shape, tensor_shape.TensorShape([])) + + def testHalfNormalShapeWithPlaceholders(self): + scale = array_ops.placeholder(dtype=dtypes.float32) + halfnorm = hn_lib.HalfNormal(scale=scale) + + with self.test_session() as sess: + # get_batch_shape should return an "" tensor. + self.assertEqual(halfnorm.batch_shape, tensor_shape.TensorShape(None)) + self.assertEqual(halfnorm.event_shape, ()) + self.assertAllEqual(halfnorm.event_shape_tensor().eval(), []) + self.assertAllEqual( + sess.run(halfnorm.batch_shape_tensor(), + feed_dict={scale: [1.0, 2.0]}), [2]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py index ece6bc077d9e21502fdfd01300a9d3e9f2c9c380..ff6092fc260660b512e8123823c63e98a023af6d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -45,6 +45,17 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, self.assertEqual([4, 5], x.shape) self.assertEqual([4, 5], log_prob_x.shape) + def testSampleAndLogProbBatch(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[[0.3, 0.7]]), + components_distribution=normal_lib.Normal( + loc=[[-1., 1]], scale=[[0.1, 0.5]])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 1], x.shape) + self.assertEqual([4, 5, 1], log_prob_x.shape) + def testSampleAndLogProbShapesBroadcastMix(self): mix_probs = np.float32([.3, .7]) bern_probs = np.float32([[.4, .6], [.25, .75]]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 103d8e186221e879d1734a097114708429f725bd..cbaf74d3f66253ae5727e1ba579e2d49235b748e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -200,6 +200,27 @@ class TransformedDistributionTest(test.TestCase): self.assertAllEqual([2], multi_logit_normal.event_shape) self.assertAllEqual([2], multi_logit_normal.event_shape_tensor().eval()) + def testCastLogDetJacobian(self): + """Test log_prob when Jacobian and log_prob dtypes do not match.""" + + with self.test_session(): + # Create an identity bijector whose jacobians have dtype int32 + int_identity = bs.Inline( + forward_fn=array_ops.identity, + inverse_fn=array_ops.identity, + inverse_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + forward_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + is_constant_jacobian=True) + normal = self._cls()( + distribution=ds.Normal(loc=0., scale=1.), + bijector=int_identity, + validate_args=True) + + y = normal.sample() + normal.log_prob(y).eval() + normal.prob(y).eval() + normal.entropy().eval() + def testEntropy(self): with self.test_session(): shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index 6049419818e18c54209f0be95d41fcecf6627b7e..0fe9f6aa78fbe845b99d0668f075b0162ec2a9f7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -18,12 +18,117 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["AbsoluteValue"] +__all__ = [ + "AbsoluteValue", +] -remove_undocumented(__name__, _allowed_symbols) + +class AbsoluteValue(bijector.Bijector): + """Computes `Y = g(X) = Abs(X)`, element-wise. + + This non-injective bijector allows for transformations of scalar distributions + with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. + + * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse + `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. + * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse + (the set inverse is the singleton `{0}`), but "works" in conjunction with + `TransformedDistribution` to produce a left semi-continuous pdf. + * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the + wrong thing, `-y, y`. This is done for efficiency. If + `validate_args == True`, `y < 0` will raise an exception. + + + ```python + tfd = tf.contrib.distributions + + abs = tfd.bijectors.AbsoluteValue() + + abs.forward([-1., 0., 1.]) + ==> [1., 0., 1.] + + abs.inverse(1.) + ==> [-1., 1.] + + # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. + abs.inverse_log_det_jacobian(1.) + ==> [0., 0.] + + # Special case handling of 0. + abs.inverse(0.) + ==> [0., 0.] + + abs.inverse_log_det_jacobian(0.) + ==> [0., 0.] + ``` + + """ + + def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + """Instantiates the `AbsoluteValue` bijector. + + Args: + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness, in particular whether inputs to `inverse` and + `inverse_log_det_jacobian` are non-negative. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init"): + super(AbsoluteValue, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return math_ops.abs(x) + + def _inverse(self, y): + if self.validate_args: + y = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + y) + return -y, y + + def _inverse_log_det_jacobian(self, y): + # If event_ndims = 2, + # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), + # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. + batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] + zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + if self.validate_args: + zeros = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + zeros) + return zeros, zeros + + @property + def _is_injective(self): + return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py deleted file mode 100644 index b84502003ab6c0c4ffdda21eea162f441509e1fa..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""AbsoluteValue bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "AbsoluteValue", -] - - -class AbsoluteValue(bijector.Bijector): - """Computes `Y = g(X) = Abs(X)`, element-wise. - - This non-injective bijector allows for transformations of scalar distributions - with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. - - * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse - `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. - * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse - (the set inverse is the singleton `{0}`), but "works" in conjunction with - `TransformedDistribution` to produce a left semi-continuous pdf. - * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the - wrong thing, `-y, y`. This is done for efficiency. If - `validate_args == True`, `y < 0` will raise an exception. - - - ```python - abs = ds.bijectors.AbsoluteValue() - - abs.forward([-1., 0., 1.]) - ==> [1., 0., 1.] - - abs.inverse(1.) - ==> [-1., 1.] - - # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. - abs.inverse_log_det_jacobian(1.) - ==> [0., 0.] - - # Special case handling of 0. - abs.inverse(0.) - ==> [0., 0.] - - abs.inverse_log_det_jacobian(0.) - ==> [0., 0.] - ``` - - """ - - def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): - """Instantiates the `AbsoluteValue` bijector. - - Args: - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness, in particular whether inputs to `inverse` and - `inverse_log_det_jacobian` are non-negative. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. - """ - self._graph_parents = [] - self._name = name - - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - - with self._name_scope("init"): - super(AbsoluteValue, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - return math_ops.abs(x) - - def _inverse(self, y): - if self.validate_args: - y = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - y) - return -y, y - - def _inverse_log_det_jacobian(self, y): - # If event_ndims = 2, - # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), - # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. - batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] - zeros = array_ops.zeros(batch_shape, dtype=y.dtype) - if self.validate_args: - zeros = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - zeros) - return zeros, zeros - - @property - def _is_injective(self): - return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 940cceff04e77cfc2f7caae5a798d135f7601b95..05bb9c2f9bdf35e222c94db3491157893da64ebd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -18,12 +18,386 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib import linalg +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Affine"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Affine", +] + + +def _as_tensor(x, name): + """Convenience to convert to `Tensor` or leave as `None`.""" + return None if x is None else ops.convert_to_tensor(x, name=name) + + +class Affine(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. + + Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. + + In TF parlance, the `scale` term is logically equivalent to: + + ```python + scale = ( + scale_identity_multiplier * tf.diag(tf.ones(d)) + + tf.diag(scale_diag) + + scale_tril + + scale_perturb_factor @ diag(scale_perturb_diag) @ + tf.transpose([scale_perturb_factor]) + ) + ``` + + The `scale` term is applied without necessarily materializing constituent + matrices, i.e., the matmul is [matrix-free]( + https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. + + Examples: + + ```python + # Y = X + b = Affine() + + # Y = X + shift + b = Affine(shift=[1., 2, 3]) + + # Y = 2 * I @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_identity_multiplier=2.) + + # Y = tf.diag(d1) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_diag=[-1., 2, 1]) # Implicitly 3x3. + + # Y = (I + v * v.T) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_perturb_factor=[[1., 0], + [0, 1], + [1, 1]]) + + # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_diag=[1., 3, 3], # Implicitly 3x3. + scale_perturb_diag=[2., 1], # Implicitly 2x2. + scale_perturb_factor=[[1., 0], + [0, 1], + [1, 1]]) + + ``` + + """ + + def __init__(self, + shift=None, + scale_identity_multiplier=None, + scale_diag=None, + scale_tril=None, + scale_perturb_factor=None, + scale_perturb_diag=None, + event_ndims=1, + validate_args=False, + name="affine"): + """Instantiates the `Affine` bijector. + + This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, + giving the forward operation: + + ```none + Y = g(X) = scale @ X + shift + ``` + + where the `scale` term is logically equivalent to: + + ```python + scale = ( + scale_identity_multiplier * tf.diag(tf.ones(d)) + + tf.diag(scale_diag) + + scale_tril + + scale_perturb_factor @ diag(scale_perturb_diag) @ + tf.transpose([scale_perturb_factor]) + ) + ``` + + If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are + specified then `scale += IdentityMatrix`. Otherwise specifying a + `scale` argument has the semantics of `scale += Expand(arg)`, i.e., + `scale_diag != None` means `scale += tf.diag(scale_diag)`. + + Args: + shift: Floating-point `Tensor`. If this is set to `None`, no shift is + applied. + scale_identity_multiplier: floating point rank 0 `Tensor` representing a + scaling done to the identity matrix. + When `scale_identity_multiplier = scale_diag = scale_tril = None` then + `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added + to `scale`. + scale_diag: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + When `None` no diagonal term is added to `scale`. + scale_tril: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k + lower triangular matrix. + When `None` no `scale_tril` term is added to `scale`. + The upper triangular elements above the diagonal are ignored. + scale_perturb_factor: Floating-point `Tensor` representing factor matrix + with last two dimensions of shape `(k, r)`. When `None`, no rank-r + update is added to `scale`. + scale_perturb_diag: Floating-point `Tensor` representing the diagonal + matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which + represents an `r x r` diagonal matrix. When `None` low rank updates will + take the form `scale_perturb_factor * scale_perturb_factor.T`. + event_ndims: Scalar `int` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `perturb_diag` is specified but not `perturb_factor`. + TypeError: if `shift` has different `dtype` from `scale` arguments. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + # Ambiguous definition of low rank update. + if scale_perturb_diag is not None and scale_perturb_factor is None: + raise ValueError("When scale_perturb_diag is specified, " + "scale_perturb_factor must be specified.") + + # Special case, only handling a scaled identity matrix. We don't know its + # dimensions, so this is special cased. + # We don't check identity_multiplier, since below we set it to 1. if all + # other scale args are None. + self._is_only_identity_multiplier = (scale_tril is None and + scale_diag is None and + scale_perturb_factor is None) + + with self._name_scope("init", values=[ + shift, scale_identity_multiplier, scale_diag, scale_tril, + scale_perturb_diag, scale_perturb_factor]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0, 1): + raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + + if event_ndims_const == 0 and not self._is_only_identity_multiplier: + raise ValueError( + "If event_ndims == 0, the only scale argument you can pass is " + "scale_identity_multiplier. All others operate on vectors.") + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + dtype = shift.dtype.base_dtype + self._shift = shift + + # When no args are specified, pretend the scale matrix is the identity + # matrix. + if (self._is_only_identity_multiplier and + scale_identity_multiplier is None): + scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) + + # self._create_scale_operator returns a LinearOperator in all cases + # except if self._is_only_identity_multiplier; in which case it + # returns a scalar Tensor. + scale = self._create_scale_operator( + identity_multiplier=scale_identity_multiplier, + diag=scale_diag, + tril=scale_tril, + perturb_diag=scale_perturb_diag, + perturb_factor=scale_perturb_factor, + shift=shift, + validate_args=validate_args) + + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + + if scale is not None and not self._is_only_identity_multiplier: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + else: + # We won't need shape inference when scale is None or when scale is a + # scalar. + batch_ndims = 0 + self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(Affine, self).__init__( + event_ndims=event_ndims, + graph_parents=( + [event_ndims] + + [self._scale] if tensor_util.is_tensor(self._scale) + else self._scale.graph_parents + + [self._shift] if self._shift is not None else []), + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + def _create_scale_operator(self, identity_multiplier, diag, tril, + perturb_diag, perturb_factor, shift, + validate_args): + """Construct `scale` from various components. + + Args: + identity_multiplier: floating point rank 0 `Tensor` representing a scaling + done to the identity matrix. + diag: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + tril: Floating-point `Tensor` representing the diagonal matrix. + `scale_tril` has shape [N1, N2, ... k], which represents a k x k lower + triangular matrix. + perturb_diag: Floating-point `Tensor` representing the diagonal matrix of + the low rank update. + perturb_factor: Floating-point `Tensor` representing factor matrix. + shift: Floating-point `Tensor` representing `shift in `scale @ X + shift`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + + Returns: + scale. In the case of scaling by a constant, scale is a + floating point `Tensor`. Otherwise, scale is a `LinearOperator`. + + Raises: + ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`. + """ + identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") + diag = _as_tensor(diag, "diag") + tril = _as_tensor(tril, "tril") + perturb_diag = _as_tensor(perturb_diag, "perturb_diag") + perturb_factor = _as_tensor(perturb_factor, "perturb_factor") + + # If possible, use the low rank update to infer the shape of + # the identity matrix, when scale represents a scaled identity matrix + # with a low rank update. + shape_hint = None + if perturb_factor is not None: + shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) + + if self._is_only_identity_multiplier: + if validate_args: + return control_flow_ops.with_dependencies( + [check_ops.assert_none_equal( + identity_multiplier, + array_ops.zeros([], identity_multiplier.dtype), + ["identity_multiplier should be non-zero."])], + identity_multiplier) + return identity_multiplier + + scale = distribution_util.make_tril_scale( + loc=shift, + scale_tril=tril, + scale_diag=diag, + scale_identity_multiplier=identity_multiplier, + validate_args=validate_args, + assert_positive=False, + shape_hint=shape_hint) + + if perturb_factor is not None: + return linalg.LinearOperatorLowRankUpdate( + scale, + u=perturb_factor, + diag_update=perturb_diag, + is_diag_update_positive=perturb_diag is None, + is_non_singular=True, # Implied by is_positive_definite=True. + is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + return scale + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self._is_only_identity_multiplier: + y *= self._scale + if self.shift is not None: + return y + self.shift + return y + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_check_scale() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self._is_only_identity_multiplier: + return x / self._scale + + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. + x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): + if self._is_only_identity_multiplier: + # We don't pad in this case and instead let the fldj be applied + # via broadcast. + event_size = distribution_util.pick_vector( + math_ops.equal(self._shaper.event_ndims, 0), + [1], array_ops.shape(x))[-1] + event_size = math_ops.cast(event_size, dtype=self._scale.dtype) + return math_ops.log(math_ops.abs(self._scale)) * event_size + return self.scale.log_abs_determinant() + + def _maybe_check_scale(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py deleted file mode 100644 index 05bb9c2f9bdf35e222c94db3491157893da64ebd..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py +++ /dev/null @@ -1,403 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Affine bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import linalg -from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Affine", -] - - -def _as_tensor(x, name): - """Convenience to convert to `Tensor` or leave as `None`.""" - return None if x is None else ops.convert_to_tensor(x, name=name) - - -class Affine(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. - - In TF parlance, the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - The `scale` term is applied without necessarily materializing constituent - matrices, i.e., the matmul is [matrix-free]( - https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - - Examples: - - ```python - # Y = X - b = Affine() - - # Y = X + shift - b = Affine(shift=[1., 2, 3]) - - # Y = 2 * I @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_identity_multiplier=2.) - - # Y = tf.diag(d1) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[-1., 2, 1]) # Implicitly 3x3. - - # Y = (I + v * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[1., 3, 3], # Implicitly 3x3. - scale_perturb_diag=[2., 1], # Implicitly 2x2. - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - ``` - - """ - - def __init__(self, - shift=None, - scale_identity_multiplier=None, - scale_diag=None, - scale_tril=None, - scale_perturb_factor=None, - scale_perturb_diag=None, - event_ndims=1, - validate_args=False, - name="affine"): - """Instantiates the `Affine` bijector. - - This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, - giving the forward operation: - - ```none - Y = g(X) = scale @ X + shift - ``` - - where the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are - specified then `scale += IdentityMatrix`. Otherwise specifying a - `scale` argument has the semantics of `scale += Expand(arg)`, i.e., - `scale_diag != None` means `scale += tf.diag(scale_diag)`. - - Args: - shift: Floating-point `Tensor`. If this is set to `None`, no shift is - applied. - scale_identity_multiplier: floating point rank 0 `Tensor` representing a - scaling done to the identity matrix. - When `scale_identity_multiplier = scale_diag = scale_tril = None` then - `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added - to `scale`. - scale_diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - When `None` no diagonal term is added to `scale`. - scale_tril: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k - lower triangular matrix. - When `None` no `scale_tril` term is added to `scale`. - The upper triangular elements above the diagonal are ignored. - scale_perturb_factor: Floating-point `Tensor` representing factor matrix - with last two dimensions of shape `(k, r)`. When `None`, no rank-r - update is added to `scale`. - scale_perturb_diag: Floating-point `Tensor` representing the diagonal - matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which - represents an `r x r` diagonal matrix. When `None` low rank updates will - take the form `scale_perturb_factor * scale_perturb_factor.T`. - event_ndims: Scalar `int` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `perturb_diag` is specified but not `perturb_factor`. - TypeError: if `shift` has different `dtype` from `scale` arguments. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - - # Ambiguous definition of low rank update. - if scale_perturb_diag is not None and scale_perturb_factor is None: - raise ValueError("When scale_perturb_diag is specified, " - "scale_perturb_factor must be specified.") - - # Special case, only handling a scaled identity matrix. We don't know its - # dimensions, so this is special cased. - # We don't check identity_multiplier, since below we set it to 1. if all - # other scale args are None. - self._is_only_identity_multiplier = (scale_tril is None and - scale_diag is None and - scale_perturb_factor is None) - - with self._name_scope("init", values=[ - shift, scale_identity_multiplier, scale_diag, scale_tril, - scale_perturb_diag, scale_perturb_factor]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0, 1): - raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - - if event_ndims_const == 0 and not self._is_only_identity_multiplier: - raise ValueError( - "If event_ndims == 0, the only scale argument you can pass is " - "scale_identity_multiplier. All others operate on vectors.") - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - dtype = shift.dtype.base_dtype - self._shift = shift - - # When no args are specified, pretend the scale matrix is the identity - # matrix. - if (self._is_only_identity_multiplier and - scale_identity_multiplier is None): - scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) - - # self._create_scale_operator returns a LinearOperator in all cases - # except if self._is_only_identity_multiplier; in which case it - # returns a scalar Tensor. - scale = self._create_scale_operator( - identity_multiplier=scale_identity_multiplier, - diag=scale_diag, - tril=scale_tril, - perturb_diag=scale_perturb_diag, - perturb_factor=scale_perturb_factor, - shift=shift, - validate_args=validate_args) - - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - - if scale is not None and not self._is_only_identity_multiplier: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - else: - # We won't need shape inference when scale is None or when scale is a - # scalar. - batch_ndims = 0 - self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(Affine, self).__init__( - event_ndims=event_ndims, - graph_parents=( - [event_ndims] + - [self._scale] if tensor_util.is_tensor(self._scale) - else self._scale.graph_parents + - [self._shift] if self._shift is not None else []), - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - def _create_scale_operator(self, identity_multiplier, diag, tril, - perturb_diag, perturb_factor, shift, - validate_args): - """Construct `scale` from various components. - - Args: - identity_multiplier: floating point rank 0 `Tensor` representing a scaling - done to the identity matrix. - diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - tril: Floating-point `Tensor` representing the diagonal matrix. - `scale_tril` has shape [N1, N2, ... k], which represents a k x k lower - triangular matrix. - perturb_diag: Floating-point `Tensor` representing the diagonal matrix of - the low rank update. - perturb_factor: Floating-point `Tensor` representing factor matrix. - shift: Floating-point `Tensor` representing `shift in `scale @ X + shift`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - - Returns: - scale. In the case of scaling by a constant, scale is a - floating point `Tensor`. Otherwise, scale is a `LinearOperator`. - - Raises: - ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`. - """ - identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") - diag = _as_tensor(diag, "diag") - tril = _as_tensor(tril, "tril") - perturb_diag = _as_tensor(perturb_diag, "perturb_diag") - perturb_factor = _as_tensor(perturb_factor, "perturb_factor") - - # If possible, use the low rank update to infer the shape of - # the identity matrix, when scale represents a scaled identity matrix - # with a low rank update. - shape_hint = None - if perturb_factor is not None: - shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) - - if self._is_only_identity_multiplier: - if validate_args: - return control_flow_ops.with_dependencies( - [check_ops.assert_none_equal( - identity_multiplier, - array_ops.zeros([], identity_multiplier.dtype), - ["identity_multiplier should be non-zero."])], - identity_multiplier) - return identity_multiplier - - scale = distribution_util.make_tril_scale( - loc=shift, - scale_tril=tril, - scale_diag=diag, - scale_identity_multiplier=identity_multiplier, - validate_args=validate_args, - assert_positive=False, - shape_hint=shape_hint) - - if perturb_factor is not None: - return linalg.LinearOperatorLowRankUpdate( - scale, - u=perturb_factor, - diag_update=perturb_diag, - is_diag_update_positive=perturb_diag is None, - is_non_singular=True, # Implied by is_positive_definite=True. - is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - return scale - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self._is_only_identity_multiplier: - y *= self._scale - if self.shift is not None: - return y + self.shift - return y - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_check_scale() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self._is_only_identity_multiplier: - return x / self._scale - - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. - x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): - if self._is_only_identity_multiplier: - # We don't pad in this case and instead let the fldj be applied - # via broadcast. - event_size = distribution_util.pick_vector( - math_ops.equal(self._shaper.event_ndims, 0), - [1], array_ops.shape(x))[-1] - event_size = math_ops.cast(event_size, dtype=self._scale.dtype) - return math_ops.log(math_ops.abs(self._scale)) * event_size - return self.scale.log_abs_determinant() - - def _maybe_check_scale(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index aca04a89df7c3ee09d5f7cc10f6779e33fa7aa66..89043b1410370074f11f2cfa59b6b6663fa62521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -18,12 +18,214 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linear_operator -_allowed_symbols = ["AffineLinearOperator"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "AffineLinearOperator", +] + + +class AffineLinearOperator(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. + + `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. + + If `X` is a scalar then the forward transformation is: `scale * X + shift` + where `*` denotes the scalar product. + + Note: we don't always simply transpose `X` (but write it this way for + brevity). Actually the input `X` undergoes the following transformation + before being premultiplied by `scale`: + + 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., + `new_sample_shape = [1]`. Otherwise do nothing. + 2. The sample shape is flattened to have one dimension, i.e., + `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. + 3. The sample dim is cyclically rotated left by 1, i.e., + `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the + event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch + dimensions. + + (For more details see `shape.make_batch_of_event_sample_matrices`.) + + The result of the above transformation is that `X` can be regarded as a batch + of matrices where each column is a draw from the distribution. After + premultiplying by `scale`, we take the inverse of this procedure. The input + `Y` also undergoes the same transformation before/after premultiplying by + `inv(scale)`. + + Example Use: + + ```python + linalg = tf.linalg + + x = [1., 2, 3] + + shift = [-1., 0., 1] + diag = [1., 2, 3] + scale = linalg.LinearOperatorDiag(diag) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # y = scale @ x + shift + y = affine.forward(x) # [0., 4, 10] + + shift = [2., 3, 1] + tril = [[1., 0, 0], + [2, 1, 0], + [3, 2, 1]] + scale = linalg.LinearOperatorLowerTriangular(tril) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift + y = affine.forward(x) # [3., 7, 11] + ``` + + """ + + def __init__(self, + shift=None, + scale=None, + event_ndims=1, + validate_args=False, + name="affine_linear_operator"): + """Instantiates the `AffineLinearOperator` bijector. + + Args: + shift: Floating-point `Tensor`. + scale: Subclass of `LinearOperator`. Represents the (batch) positive + definite matrix `M` in `R^{k x k}`. + event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `event_ndims` is not 0 or 1. + TypeError: if `scale` is not a `LinearOperator`. + TypeError: if `shift.dtype` does not match `scale.dtype`. + ValueError: if not `scale.is_non_singular`. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + graph_parents = [] + with self._name_scope("init", values=[shift]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + if tensor_util.constant_value(event_ndims) is not None: + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims not in (0, 1): + raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + graph_parents += [event_ndims] + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + graph_parents += [shift] + dtype = shift.dtype.base_dtype + self._shift = shift + + if scale is not None: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + if not isinstance(scale, linear_operator.LinearOperator): + raise TypeError("scale is not an instance of tf.LinearOperator") + if validate_args and not scale.is_non_singular: + raise ValueError("Scale matrix must be non-singular.") + graph_parents += scale.graph_parents + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + graph_parents += [batch_ndims] + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + else: + batch_ndims = 0 # We won't need shape inference when scale is None. + self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(AffineLinearOperator, self).__init__( + event_ndims=event_ndims, + graph_parents=graph_parents, + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self.scale is not None: + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self.scale is not None: + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. + x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument + if self.scale is None: + return constant_op.constant(0, dtype=x.dtype.base_dtype) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + return self.scale.log_abs_determinant() + + def _maybe_collect_assertions(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py deleted file mode 100644 index 89043b1410370074f11f2cfa59b6b6663fa62521..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""AffineLinearOperator bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.linalg import linear_operator - - -__all__ = [ - "AffineLinearOperator", -] - - -class AffineLinearOperator(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. - - If `X` is a scalar then the forward transformation is: `scale * X + shift` - where `*` denotes the scalar product. - - Note: we don't always simply transpose `X` (but write it this way for - brevity). Actually the input `X` undergoes the following transformation - before being premultiplied by `scale`: - - 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., - `new_sample_shape = [1]`. Otherwise do nothing. - 2. The sample shape is flattened to have one dimension, i.e., - `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. - 3. The sample dim is cyclically rotated left by 1, i.e., - `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the - event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch - dimensions. - - (For more details see `shape.make_batch_of_event_sample_matrices`.) - - The result of the above transformation is that `X` can be regarded as a batch - of matrices where each column is a draw from the distribution. After - premultiplying by `scale`, we take the inverse of this procedure. The input - `Y` also undergoes the same transformation before/after premultiplying by - `inv(scale)`. - - Example Use: - - ```python - linalg = tf.linalg - - x = [1., 2, 3] - - shift = [-1., 0., 1] - diag = [1., 2, 3] - scale = linalg.LinearOperatorDiag(diag) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # y = scale @ x + shift - y = affine.forward(x) # [0., 4, 10] - - shift = [2., 3, 1] - tril = [[1., 0, 0], - [2, 1, 0], - [3, 2, 1]] - scale = linalg.LinearOperatorLowerTriangular(tril) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift - y = affine.forward(x) # [3., 7, 11] - ``` - - """ - - def __init__(self, - shift=None, - scale=None, - event_ndims=1, - validate_args=False, - name="affine_linear_operator"): - """Instantiates the `AffineLinearOperator` bijector. - - Args: - shift: Floating-point `Tensor`. - scale: Subclass of `LinearOperator`. Represents the (batch) positive - definite matrix `M` in `R^{k x k}`. - event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `event_ndims` is not 0 or 1. - TypeError: if `scale` is not a `LinearOperator`. - TypeError: if `shift.dtype` does not match `scale.dtype`. - ValueError: if not `scale.is_non_singular`. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - graph_parents = [] - with self._name_scope("init", values=[shift]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - if tensor_util.constant_value(event_ndims) is not None: - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims not in (0, 1): - raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - graph_parents += [event_ndims] - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - graph_parents += [shift] - dtype = shift.dtype.base_dtype - self._shift = shift - - if scale is not None: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - if not isinstance(scale, linear_operator.LinearOperator): - raise TypeError("scale is not an instance of tf.LinearOperator") - if validate_args and not scale.is_non_singular: - raise ValueError("Scale matrix must be non-singular.") - graph_parents += scale.graph_parents - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - graph_parents += [batch_ndims] - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - else: - batch_ndims = 0 # We won't need shape inference when scale is None. - self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(AffineLinearOperator, self).__init__( - event_ndims=event_ndims, - graph_parents=graph_parents, - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self.scale is not None: - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self.scale is not None: - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. - x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument - if self.scale is None: - return constant_op.constant(0, dtype=x.dtype.base_dtype) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - return self.scale.log_abs_determinant() - - def _maybe_collect_assertions(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 0db10fb75c8483a8209f39370362b05a03d047ca..3ce7c26213034c7345a20faa803c94a1bfa8d579 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -18,12 +18,151 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.chain_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import itertools -_allowed_symbols = ["Chain"] +from tensorflow.python.framework import constant_op +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Chain", +] + + +class Chain(bijector.Bijector): + """Bijector which applies a sequence of bijectors. + + Example Use: + + ```python + chain = Chain([Exp(), Softplus()], name="one_plus_exp") + ``` + + Results in: + + * Forward: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).forward(x) + = exp.forward(softplus.forward(x)) + = tf.exp(tf.log(1. + tf.exp(x))) + = 1. + tf.exp(x) + ``` + + * Inverse: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).inverse(y) + = softplus.inverse(exp.inverse(y)) + = tf.log(tf.exp(tf.log(y)) - 1.) + = tf.log(y - 1.) + ``` + + """ + + def __init__(self, bijectors=None, validate_args=False, name=None): + """Instantiates `Chain` bijector. + + Args: + bijectors: Python `list` of bijector instances. An empty list makes this + bijector equivalent to the `Identity` bijector. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. Default: + E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. + + Raises: + ValueError: if bijectors have different dtypes. + """ + if bijectors is None: + bijectors = () + self._bijectors = bijectors + + for a_bijector in bijectors: + if not a_bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijector ({})".format( + a_bijector.name)) + + dtype = list(set([b.dtype for b in bijectors])) + if len(dtype) > 2: + raise ValueError("incompatible dtypes: %s" % dtype) + elif len(dtype) == 2: + dtype = dtype[1] if dtype[0] is None else dtype[0] + event_ndims = bijectors[0].event_ndims + elif len(dtype) == 1: + dtype = dtype[0] + event_ndims = bijectors[0].event_ndims + else: + dtype = None + event_ndims = None + + super(Chain, self).__init__( + graph_parents=list(itertools.chain.from_iterable( + b.graph_parents for b in bijectors)), + is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), + validate_args=validate_args, + dtype=dtype, + event_ndims=event_ndims, + name=name or ("identity" if not bijectors else + "_of_".join(["chain"] + [b.name for b in bijectors]))) + + @property + def bijectors(self): + return self._bijectors + + def _shape_helper(self, func_name, input_shape, reverse): + new_shape = input_shape + for b in reversed(self.bijectors) if reverse else self.bijectors: + func = getattr(b, func_name, None) + if func is None: + raise ValueError("unable to call %s on bijector %s (%s)" % + (func_name, b.name, func)) + new_shape = func(new_shape) + return new_shape + + def _forward_event_shape(self, input_shape): + return self._shape_helper("forward_event_shape", input_shape, + reverse=True) + + def _forward_event_shape_tensor(self, input_shape): + return self._shape_helper( + "forward_event_shape_tensor", input_shape, reverse=True) + + def _inverse_event_shape(self, output_shape): + return self._shape_helper("inverse_event_shape", output_shape, + reverse=False) + + def _inverse_event_shape_tensor(self, output_shape): + return self._shape_helper("inverse_event_shape_tensor", output_shape, + reverse=False) + + def _inverse(self, y, **kwargs): + for b in self.bijectors: + y = b.inverse(y, **kwargs.get(b.name, {})) + return y + + def _inverse_log_det_jacobian(self, y, **kwargs): + ildj = constant_op.constant(0., dtype=y.dtype, + name="inverse_log_det_jacobian") + for b in self.bijectors: + ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) + y = b.inverse(y, **kwargs.get(b.name, {})) + return ildj + + def _forward(self, x, **kwargs): + for b in reversed(self.bijectors): + x = b.forward(x, **kwargs.get(b.name, {})) + return x + + def _forward_log_det_jacobian(self, x, **kwargs): + fldj = constant_op.constant(0., dtype=x.dtype, + name="forward_log_det_jacobian") + for b in reversed(self.bijectors): + fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) + x = b.forward(x, **kwargs.get(b.name, {})) + return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py deleted file mode 100644 index 3ce7c26213034c7345a20faa803c94a1bfa8d579..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Chain bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools - -from tensorflow.python.framework import constant_op -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Chain", -] - - -class Chain(bijector.Bijector): - """Bijector which applies a sequence of bijectors. - - Example Use: - - ```python - chain = Chain([Exp(), Softplus()], name="one_plus_exp") - ``` - - Results in: - - * Forward: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).forward(x) - = exp.forward(softplus.forward(x)) - = tf.exp(tf.log(1. + tf.exp(x))) - = 1. + tf.exp(x) - ``` - - * Inverse: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).inverse(y) - = softplus.inverse(exp.inverse(y)) - = tf.log(tf.exp(tf.log(y)) - 1.) - = tf.log(y - 1.) - ``` - - """ - - def __init__(self, bijectors=None, validate_args=False, name=None): - """Instantiates `Chain` bijector. - - Args: - bijectors: Python `list` of bijector instances. An empty list makes this - bijector equivalent to the `Identity` bijector. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. Default: - E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. - - Raises: - ValueError: if bijectors have different dtypes. - """ - if bijectors is None: - bijectors = () - self._bijectors = bijectors - - for a_bijector in bijectors: - if not a_bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijector ({})".format( - a_bijector.name)) - - dtype = list(set([b.dtype for b in bijectors])) - if len(dtype) > 2: - raise ValueError("incompatible dtypes: %s" % dtype) - elif len(dtype) == 2: - dtype = dtype[1] if dtype[0] is None else dtype[0] - event_ndims = bijectors[0].event_ndims - elif len(dtype) == 1: - dtype = dtype[0] - event_ndims = bijectors[0].event_ndims - else: - dtype = None - event_ndims = None - - super(Chain, self).__init__( - graph_parents=list(itertools.chain.from_iterable( - b.graph_parents for b in bijectors)), - is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), - validate_args=validate_args, - dtype=dtype, - event_ndims=event_ndims, - name=name or ("identity" if not bijectors else - "_of_".join(["chain"] + [b.name for b in bijectors]))) - - @property - def bijectors(self): - return self._bijectors - - def _shape_helper(self, func_name, input_shape, reverse): - new_shape = input_shape - for b in reversed(self.bijectors) if reverse else self.bijectors: - func = getattr(b, func_name, None) - if func is None: - raise ValueError("unable to call %s on bijector %s (%s)" % - (func_name, b.name, func)) - new_shape = func(new_shape) - return new_shape - - def _forward_event_shape(self, input_shape): - return self._shape_helper("forward_event_shape", input_shape, - reverse=True) - - def _forward_event_shape_tensor(self, input_shape): - return self._shape_helper( - "forward_event_shape_tensor", input_shape, reverse=True) - - def _inverse_event_shape(self, output_shape): - return self._shape_helper("inverse_event_shape", output_shape, - reverse=False) - - def _inverse_event_shape_tensor(self, output_shape): - return self._shape_helper("inverse_event_shape_tensor", output_shape, - reverse=False) - - def _inverse(self, y, **kwargs): - for b in self.bijectors: - y = b.inverse(y, **kwargs.get(b.name, {})) - return y - - def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant(0., dtype=y.dtype, - name="inverse_log_det_jacobian") - for b in self.bijectors: - ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) - y = b.inverse(y, **kwargs.get(b.name, {})) - return ildj - - def _forward(self, x, **kwargs): - for b in reversed(self.bijectors): - x = b.forward(x, **kwargs.get(b.name, {})) - return x - - def _forward_log_det_jacobian(self, x, **kwargs): - fldj = constant_op.constant(0., dtype=x.dtype, - name="forward_log_det_jacobian") - for b in reversed(self.bijectors): - fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) - x = b.forward(x, **kwargs.get(b.name, {})) - return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 4686af8bc42a3232cb3a34f2cfcce8323c5896dd..cbd60f92a60612c6cf791b2c7708a3310c6e2b6b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -18,12 +18,219 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["CholeskyOuterProduct"] +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "CholeskyOuterProduct", +] + + +class CholeskyOuterProduct(bijector.Bijector): + """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. + + `event_ndims` must be 0 or 2, i.e., scalar or matrix. + + Note: the upper-triangular part of X is ignored (whether or not its zero). + + The surjectivity of g as a map from the set of n x n positive-diagonal + lower-triangular matrices to the set of SPD matrices follows immediately from + executing the Cholesky factorization algorithm on an SPD matrix A to produce a + positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. + + To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular + with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then + `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. + Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal + lower-triangular matrix follows from `inv(L_1)` being positive-diagonal + lower-triangular (which follows from the diagonal of a triangular matrix being + its spectrum), and that the product of two positive-diagonal lower-triangular + matrices is another positive-diagonal lower-triangular matrix. + + A simple inductive argument (proceding one column of L_3 at a time) shows + that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- + diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. + + Examples: + + ```python + bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 2], [2, 5]], i.e., x @ x.T + + bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) + # Result: [[1., 0], [2, 1]], i.e., cholesky(y). + ``` + + """ + + def __init__(self, event_ndims=2, validate_args=False, + name="cholesky_outer_product"): + """Instantiates the `CholeskyOuterProduct` bijector. + + Args: + event_ndims: `constant` `int32` scalar `Tensor` indicating the number of + dimensions associated with a particular draw from the distribution. Must + be 0 or 2. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if event_ndims is neither 0 or 2. + """ + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 2]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") + self._static_event_ndims = event_ndims + super(CholeskyOuterProduct, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self._static_event_ndims == 0: + return math_ops.square(x) + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least(x, 2) + shape = array_ops.shape(x) + is_square = check_ops.assert_equal(shape[-2], shape[-1]) + x = control_flow_ops.with_dependencies([is_matrix, is_square], x) + # For safety, explicitly zero-out the upper triangular part. + x = array_ops.matrix_band_part(x, -1, 0) + return math_ops.matmul(x, x, adjoint_b=True) + + def _inverse(self, y): + return (math_ops.sqrt(y) if self._static_event_ndims == 0 + else linalg_ops.cholesky(y)) + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(x=self._inverse(y)) + + def _forward_log_det_jacobian(self, x): + # Let Y be a symmetric, positive definite matrix and write: + # Y = X X.T + # where X is lower-triangular. + # + # Observe that, + # dY[i,j]/dX[a,b] + # = d/dX[a,b] { X[i,:] X[j,:] } + # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } + # + # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is + # symmetric and X is lower-triangular, we need vectors of dimension: + # d = p (p + 1) / 2 + # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., + # k = { i (i + 1) / 2 + j i>=j + # { undef ij thus i,j!=a. + # + # Since the Jacobian is lower-triangular, we need only compute the product + # of diagonal elements: + # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] + # = X[j,j] + I[i=j] X[i,j] + # = 2 X[j,j]. + # Since there is a 2 X[j,j] term for every lower-triangular element of X we + # conclude: + # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. + if self._static_event_ndims == 0: + if self.validate_args: + is_positive = check_ops.assert_positive( + x, message="All elements must be positive.") + x = control_flow_ops.with_dependencies([is_positive], x) + return np.log(2.) + math_ops.log(x) + + diag = array_ops.matrix_diag_part(x) + + # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output + # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the + # output is unchanged. + diag = self._make_columnar(diag) + + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must be a (batch of) matrix.") + shape = array_ops.shape(x) + is_square = check_ops.assert_equal( + shape[-2], shape[-1], + message="Input must be a (batch of) square matrix.") + # Assuming lower-triangular means we only need check diag>0. + is_positive_definite = check_ops.assert_positive( + diag, message="Input must be positive definite.") + x = control_flow_ops.with_dependencies( + [is_matrix, is_square, is_positive_definite], x) + + # Create a vector equal to: [p, p-1, ..., 2, 1]. + if x.get_shape().ndims is None or x.get_shape()[-1].value is None: + p_int = array_ops.shape(x)[-1] + p_float = math_ops.cast(p_int, dtype=x.dtype) + else: + p_int = x.get_shape()[-1].value + p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) + exponents = math_ops.linspace(p_float, 1., p_int) + + sum_weighted_log_diag = array_ops.squeeze( + math_ops.matmul(math_ops.log(diag), + exponents[..., array_ops.newaxis]), + squeeze_dims=-1) + fldj = p_float * np.log(2.) + sum_weighted_log_diag + + return fldj + + def _make_columnar(self, x): + """Ensures non-scalar input has at least one column. + + Example: + If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. + + If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. + + If `x = 1` then the output is unchanged. + + Args: + x: `Tensor`. + + Returns: + columnar_x: `Tensor` with at least two dimensions. + """ + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 1: + x = x[array_ops.newaxis, :] + return x + shape = array_ops.shape(x) + maybe_expanded_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 1), + [1], np.array([], dtype=np.int32)), + shape[-1:], + ], 0) + return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py deleted file mode 100644 index cbd60f92a60612c6cf791b2c7708a3310c6e2b6b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""CholeskyOuterProduct bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "CholeskyOuterProduct", -] - - -class CholeskyOuterProduct(bijector.Bijector): - """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. - - `event_ndims` must be 0 or 2, i.e., scalar or matrix. - - Note: the upper-triangular part of X is ignored (whether or not its zero). - - The surjectivity of g as a map from the set of n x n positive-diagonal - lower-triangular matrices to the set of SPD matrices follows immediately from - executing the Cholesky factorization algorithm on an SPD matrix A to produce a - positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. - - To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular - with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then - `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. - Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal - lower-triangular matrix follows from `inv(L_1)` being positive-diagonal - lower-triangular (which follows from the diagonal of a triangular matrix being - its spectrum), and that the product of two positive-diagonal lower-triangular - matrices is another positive-diagonal lower-triangular matrix. - - A simple inductive argument (proceding one column of L_3 at a time) shows - that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- - diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - - Examples: - - ```python - bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) - # Result: [[1., 2], [2, 5]], i.e., x @ x.T - - bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) - # Result: [[1., 0], [2, 1]], i.e., cholesky(y). - ``` - - """ - - def __init__(self, event_ndims=2, validate_args=False, - name="cholesky_outer_product"): - """Instantiates the `CholeskyOuterProduct` bijector. - - Args: - event_ndims: `constant` `int32` scalar `Tensor` indicating the number of - dimensions associated with a particular draw from the distribution. Must - be 0 or 2. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if event_ndims is neither 0 or 2. - """ - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 2]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") - self._static_event_ndims = event_ndims - super(CholeskyOuterProduct, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self._static_event_ndims == 0: - return math_ops.square(x) - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least(x, 2) - shape = array_ops.shape(x) - is_square = check_ops.assert_equal(shape[-2], shape[-1]) - x = control_flow_ops.with_dependencies([is_matrix, is_square], x) - # For safety, explicitly zero-out the upper triangular part. - x = array_ops.matrix_band_part(x, -1, 0) - return math_ops.matmul(x, x, adjoint_b=True) - - def _inverse(self, y): - return (math_ops.sqrt(y) if self._static_event_ndims == 0 - else linalg_ops.cholesky(y)) - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(x=self._inverse(y)) - - def _forward_log_det_jacobian(self, x): - # Let Y be a symmetric, positive definite matrix and write: - # Y = X X.T - # where X is lower-triangular. - # - # Observe that, - # dY[i,j]/dX[a,b] - # = d/dX[a,b] { X[i,:] X[j,:] } - # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } - # - # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is - # symmetric and X is lower-triangular, we need vectors of dimension: - # d = p (p + 1) / 2 - # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., - # k = { i (i + 1) / 2 + j i>=j - # { undef ij thus i,j!=a. - # - # Since the Jacobian is lower-triangular, we need only compute the product - # of diagonal elements: - # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] - # = X[j,j] + I[i=j] X[i,j] - # = 2 X[j,j]. - # Since there is a 2 X[j,j] term for every lower-triangular element of X we - # conclude: - # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. - if self._static_event_ndims == 0: - if self.validate_args: - is_positive = check_ops.assert_positive( - x, message="All elements must be positive.") - x = control_flow_ops.with_dependencies([is_positive], x) - return np.log(2.) + math_ops.log(x) - - diag = array_ops.matrix_diag_part(x) - - # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output - # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the - # output is unchanged. - diag = self._make_columnar(diag) - - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least( - x, 2, message="Input must be a (batch of) matrix.") - shape = array_ops.shape(x) - is_square = check_ops.assert_equal( - shape[-2], shape[-1], - message="Input must be a (batch of) square matrix.") - # Assuming lower-triangular means we only need check diag>0. - is_positive_definite = check_ops.assert_positive( - diag, message="Input must be positive definite.") - x = control_flow_ops.with_dependencies( - [is_matrix, is_square, is_positive_definite], x) - - # Create a vector equal to: [p, p-1, ..., 2, 1]. - if x.get_shape().ndims is None or x.get_shape()[-1].value is None: - p_int = array_ops.shape(x)[-1] - p_float = math_ops.cast(p_int, dtype=x.dtype) - else: - p_int = x.get_shape()[-1].value - p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) - exponents = math_ops.linspace(p_float, 1., p_int) - - sum_weighted_log_diag = array_ops.squeeze( - math_ops.matmul(math_ops.log(diag), - exponents[..., array_ops.newaxis]), - squeeze_dims=-1) - fldj = p_float * np.log(2.) + sum_weighted_log_diag - - return fldj - - def _make_columnar(self, x): - """Ensures non-scalar input has at least one column. - - Example: - If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. - - If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. - - If `x = 1` then the output is unchanged. - - Args: - x: `Tensor`. - - Returns: - columnar_x: `Tensor` with at least two dimensions. - """ - if x.get_shape().ndims is not None: - if x.get_shape().ndims == 1: - x = x[array_ops.newaxis, :] - return x - shape = array_ops.shape(x) - maybe_expanded_shape = array_ops.concat([ - shape[:-1], - distribution_util.pick_vector( - math_ops.equal(array_ops.rank(x), 1), - [1], np.array([], dtype=np.int32)), - shape[-1:], - ], 0) - return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py index d254b635d28099a09a2054536f04ffee3a355b2f..ccb1f029277bc07011df7be047a075274f2b3a27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py @@ -18,12 +18,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["ConditionalBijector"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = ["ConditionalBijector"] + + +class ConditionalBijector(bijector.Bijector): + """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward(self, x, name="forward", **condition_kwargs): + return self._call_forward(x, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse(self, y, name="inverse", **condition_kwargs): + return self._call_inverse(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse_log_det_jacobian( + self, y, name="inverse_log_det_jacobian", **condition_kwargs): + return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward_log_det_jacobian( + self, x, name="forward_log_det_jacobian", **condition_kwargs): + return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py deleted file mode 100644 index ccb1f029277bc07011df7be047a075274f2b3a27..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ConditionalBijector base.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = ["ConditionalBijector"] - - -class ConditionalBijector(bijector.Bijector): - """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward(self, x, name="forward", **condition_kwargs): - return self._call_forward(x, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse(self, y, name="inverse", **condition_kwargs): - return self._call_inverse(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse_log_det_jacobian( - self, y, name="inverse_log_det_jacobian", **condition_kwargs): - return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward_log_det_jacobian( - self, x, name="forward_log_det_jacobian", **condition_kwargs): - return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 399d713098eb7223601beb9518dc51dd6160ad64..b1ff840d62a73c941a4d67dec73b5c9f4d5353f9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -18,12 +18,49 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.exp_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import power_transform -_allowed_symbols = ["Exp"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Exp", +] + + +class Exp(power_transform.PowerTransform): + """Compute `Y = g(X) = exp(X)`. + + Example Use: + + ```python + # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + exp = Exp(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + exp(x) == exp.forward(x) + log(x) == exp.inverse(x) + ``` + + Note: the exp(.) is applied element-wise but the Jacobian is a reduction + over the event space. + """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="exp"): + """Instantiates the `Exp` bijector. + + Args: + event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + super(Exp, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py deleted file mode 100644 index b1ff840d62a73c941a4d67dec73b5c9f4d5353f9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Exp bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import power_transform - - -__all__ = [ - "Exp", -] - - -class Exp(power_transform.PowerTransform): - """Compute `Y = g(X) = exp(X)`. - - Example Use: - - ```python - # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - exp = Exp(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - exp(x) == exp.forward(x) - log(x) == exp.inverse(x) - ``` - - Note: the exp(.) is applied element-wise but the Jacobian is a reduction - over the event space. - """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="exp"): - """Instantiates the `Exp` bijector. - - Args: - event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - super(Exp, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index cf37aa51115ed98ab263bc03bcb297a03432a7ae..67f39785563255be0fe154aca3cbcf01c6a01e73 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -18,12 +18,107 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Gumbel"] +__all__ = [ + "Gumbel", +] -remove_undocumented(__name__, _allowed_symbols) + +class Gumbel(bijector.Bijector): + """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + + This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): + + ```none + Y ~ Gumbel(loc, scale) + pdf(y; loc, scale) = exp( + -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale + ``` + """ + + def __init__(self, + loc=0., + scale=1., + event_ndims=0, + validate_args=False, + name="gumbel"): + """Instantiates the `Gumbel` bijector. + + Args: + loc: Float-like `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + scale: Positive Float-like `Tensor` that is the same dtype and is + broadcastable with `loc`. + This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[loc, scale]): + self._loc = ops.convert_to_tensor(loc, name="loc") + self._scale = ops.convert_to_tensor(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, message="Argument scale was not positive") + ], self._scale) + + super(Gumbel, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def loc(self): + """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._loc + + @property + def scale(self): + """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._scale + + def _forward(self, x): + z = (x - self.loc) / self.scale + return math_ops.exp(-math_ops.exp(-z)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.loc - self.scale * math_ops.log(-math_ops.log(y)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + event_dims = self._event_dims_tensor(x) + z = (x - self.loc) / self.scale + return math_ops.reduce_sum( + -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, + constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py deleted file mode 100644 index 67f39785563255be0fe154aca3cbcf01c6a01e73..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Gumbel bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "Gumbel", -] - - -class Gumbel(bijector.Bijector): - """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - - This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): - - ```none - Y ~ Gumbel(loc, scale) - pdf(y; loc, scale) = exp( - -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale - ``` - """ - - def __init__(self, - loc=0., - scale=1., - event_ndims=0, - validate_args=False, - name="gumbel"): - """Instantiates the `Gumbel` bijector. - - Args: - loc: Float-like `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - scale: Positive Float-like `Tensor` that is the same dtype and is - broadcastable with `loc`. - This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[loc, scale]): - self._loc = ops.convert_to_tensor(loc, name="loc") - self._scale = ops.convert_to_tensor(scale, name="scale") - check_ops.assert_same_float_dtype([self._loc, self._scale]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, message="Argument scale was not positive") - ], self._scale) - - super(Gumbel, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def loc(self): - """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._loc - - @property - def scale(self): - """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._scale - - def _forward(self, x): - z = (x - self.loc) / self.scale - return math_ops.exp(-math_ops.exp(-z)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.loc - self.scale * math_ops.log(-math_ops.log(y)) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - event_dims = self._event_dims_tensor(x) - z = (x - self.loc) / self.scale - return math_ops.reduce_sum( - -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, - constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index db10c3fc3a9135b4c408ada74622ba9b360f9ec1..fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -18,12 +18,124 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.inline_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Inline"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Inline", +] + + +class Inline(bijector.Bijector): + """Bijector constructed from custom callables. + + Example Use: + + ```python + exp = Inline( + forward_fn=tf.exp, + inverse_fn=tf.log, + inverse_log_det_jacobian_fn=( + lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), + name="exp") + ``` + + The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. + """ + + def __init__(self, + forward_fn=None, + inverse_fn=None, + inverse_log_det_jacobian_fn=None, + forward_log_det_jacobian_fn=None, + forward_event_shape_fn=None, + forward_event_shape_tensor_fn=None, + inverse_event_shape_fn=None, + inverse_event_shape_tensor_fn=None, + is_constant_jacobian=False, + validate_args=False, + name="inline"): + """Creates a `Bijector` from callables. + + Args: + forward_fn: Python callable implementing the forward transformation. + inverse_fn: Python callable implementing the inverse transformation. + inverse_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the inverse transformation. + forward_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the forward transformation. + forward_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + forward_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + is_constant_jacobian: Python `bool` indicating that the Jacobian is + constant for all input arguments. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + super(Inline, self).__init__( + event_ndims=0, + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + self._forward_fn = forward_fn + self._inverse_fn = inverse_fn + self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn + self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn + self._forward_event_shape_fn = forward_event_shape_fn + self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn + self._inverse_event_shape_fn = inverse_event_shape_fn + self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn + + def _forward_event_shape(self, input_shape): + if self._forward_event_shape_fn is None: + # By default assume shape doesn't change. + return input_shape + return self._forward_event_shape_fn(input_shape) + + def _forward_event_shape_tensor(self, input_shape): + if self._forward_event_shape_tensor_fn is None: + # By default assume shape doesn't change. + return input_shape + return self._forward_event_shape_tensor_fn(input_shape) + + def _inverse_event_shape(self, output_shape): + if self._inverse_event_shape_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_fn(output_shape) + + def _inverse_event_shape_tensor(self, output_shape): + if self._inverse_event_shape_tensor_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_tensor_fn(output_shape) + + def _forward(self, x, **kwargs): + if not callable(self._forward_fn): + raise NotImplementedError( + "forward_fn is not a callable function.") + return self._forward_fn(x, **kwargs) + + def _inverse(self, y, **kwargs): + if not callable(self._inverse_fn): + raise NotImplementedError( + "inverse_fn is not a callable function.") + return self._inverse_fn(y, **kwargs) + + def _inverse_log_det_jacobian(self, y, **kwargs): + if not callable(self._inverse_log_det_jacobian_fn): + raise NotImplementedError( + "inverse_log_det_jacobian_fn is not a callable function.") + return self._inverse_log_det_jacobian_fn(y, **kwargs) + + def _forward_log_det_jacobian(self, y, **kwargs): + if not callable(self._forward_log_det_jacobian_fn): + raise NotImplementedError( + "forward_log_det_jacobian_fn is not a callable function.") + return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py deleted file mode 100644 index fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Inline bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Inline", -] - - -class Inline(bijector.Bijector): - """Bijector constructed from custom callables. - - Example Use: - - ```python - exp = Inline( - forward_fn=tf.exp, - inverse_fn=tf.log, - inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), - name="exp") - ``` - - The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. - """ - - def __init__(self, - forward_fn=None, - inverse_fn=None, - inverse_log_det_jacobian_fn=None, - forward_log_det_jacobian_fn=None, - forward_event_shape_fn=None, - forward_event_shape_tensor_fn=None, - inverse_event_shape_fn=None, - inverse_event_shape_tensor_fn=None, - is_constant_jacobian=False, - validate_args=False, - name="inline"): - """Creates a `Bijector` from callables. - - Args: - forward_fn: Python callable implementing the forward transformation. - inverse_fn: Python callable implementing the inverse transformation. - inverse_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the inverse transformation. - forward_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the forward transformation. - forward_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. - forward_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - is_constant_jacobian: Python `bool` indicating that the Jacobian is - constant for all input arguments. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - super(Inline, self).__init__( - event_ndims=0, - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - self._forward_fn = forward_fn - self._inverse_fn = inverse_fn - self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn - self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn - self._forward_event_shape_fn = forward_event_shape_fn - self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn - self._inverse_event_shape_fn = inverse_event_shape_fn - self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn - - def _forward_event_shape(self, input_shape): - if self._forward_event_shape_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_fn(input_shape) - - def _forward_event_shape_tensor(self, input_shape): - if self._forward_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_tensor_fn(input_shape) - - def _inverse_event_shape(self, output_shape): - if self._inverse_event_shape_fn is None: - # By default assume shape doesn't change. - return output_shape - return self._inverse_event_shape_fn(output_shape) - - def _inverse_event_shape_tensor(self, output_shape): - if self._inverse_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return output_shape - return self._inverse_event_shape_tensor_fn(output_shape) - - def _forward(self, x, **kwargs): - if not callable(self._forward_fn): - raise NotImplementedError( - "forward_fn is not a callable function.") - return self._forward_fn(x, **kwargs) - - def _inverse(self, y, **kwargs): - if not callable(self._inverse_fn): - raise NotImplementedError( - "inverse_fn is not a callable function.") - return self._inverse_fn(y, **kwargs) - - def _inverse_log_det_jacobian(self, y, **kwargs): - if not callable(self._inverse_log_det_jacobian_fn): - raise NotImplementedError( - "inverse_log_det_jacobian_fn is not a callable function.") - return self._inverse_log_det_jacobian_fn(y, **kwargs) - - def _forward_log_det_jacobian(self, y, **kwargs): - if not callable(self._forward_log_det_jacobian_fn): - raise NotImplementedError( - "forward_log_det_jacobian_fn is not a callable function.") - return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index c134e10109ce5065eb58de1d847e3c487258954c..2c603fe61f36dd27f4984fe6c13c11f2fb534321 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -18,12 +18,85 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.invert_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector as bijector_lib -_allowed_symbols = ["Invert"] +__all__ = [ + "Invert", +] -remove_undocumented(__name__, _allowed_symbols) + +class Invert(bijector_lib.Bijector): + """Bijector which inverts another Bijector. + + Example Use: [ExpGammaDistribution (see Background & Context)]( + https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) + models `Y=log(X)` where `X ~ Gamma`. + + ```python + exp_gamma_distribution = TransformedDistribution( + distribution=Gamma(concentration=1., rate=2.), + bijector=bijector.Invert(bijector.Exp()) + ``` + + """ + + def __init__(self, bijector, validate_args=False, name=None): + """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. + + Note: An inverted bijector's `inverse_log_det_jacobian` is often more + efficient if the base bijector implements `_forward_log_det_jacobian`. If + `_forward_log_det_jacobian` is not implemented then the following code is + used: + + ```python + y = self.inverse(x, **kwargs) + return -self.inverse_log_det_jacobian(y, **kwargs) + ``` + + Args: + bijector: Bijector instance. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + + if not bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijectors.") + + self._bijector = bijector + super(Invert, self).__init__( + event_ndims=bijector.event_ndims, + graph_parents=bijector.graph_parents, + is_constant_jacobian=bijector.is_constant_jacobian, + validate_args=validate_args, + dtype=bijector.dtype, + name=name or "_".join(["invert", bijector.name])) + + def _forward_event_shape(self, input_shape): + return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access + + def _forward_event_shape_tensor(self, input_shape): + return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access + + def _inverse_event_shape(self, output_shape): + return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access + + def _inverse_event_shape_tensor(self, output_shape): + return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access + + @property + def bijector(self): + return self._bijector + + def _forward(self, x, **kwargs): + return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access + + def _inverse(self, y, **kwargs): + return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access + + def _inverse_log_det_jacobian(self, y, **kwargs): + return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access + + def _forward_log_det_jacobian(self, x, **kwargs): + return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py deleted file mode 100644 index 2c603fe61f36dd27f4984fe6c13c11f2fb534321..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Invert bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector as bijector_lib - -__all__ = [ - "Invert", -] - - -class Invert(bijector_lib.Bijector): - """Bijector which inverts another Bijector. - - Example Use: [ExpGammaDistribution (see Background & Context)]( - https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) - models `Y=log(X)` where `X ~ Gamma`. - - ```python - exp_gamma_distribution = TransformedDistribution( - distribution=Gamma(concentration=1., rate=2.), - bijector=bijector.Invert(bijector.Exp()) - ``` - - """ - - def __init__(self, bijector, validate_args=False, name=None): - """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. - - Note: An inverted bijector's `inverse_log_det_jacobian` is often more - efficient if the base bijector implements `_forward_log_det_jacobian`. If - `_forward_log_det_jacobian` is not implemented then the following code is - used: - - ```python - y = self.inverse(x, **kwargs) - return -self.inverse_log_det_jacobian(y, **kwargs) - ``` - - Args: - bijector: Bijector instance. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - - if not bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijectors.") - - self._bijector = bijector - super(Invert, self).__init__( - event_ndims=bijector.event_ndims, - graph_parents=bijector.graph_parents, - is_constant_jacobian=bijector.is_constant_jacobian, - validate_args=validate_args, - dtype=bijector.dtype, - name=name or "_".join(["invert", bijector.name])) - - def _forward_event_shape(self, input_shape): - return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access - - def _forward_event_shape_tensor(self, input_shape): - return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access - - def _inverse_event_shape(self, output_shape): - return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access - - def _inverse_event_shape_tensor(self, output_shape): - return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access - - @property - def bijector(self): - return self._bijector - - def _forward(self, x, **kwargs): - return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access - - def _inverse(self, y, **kwargs): - return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access - - def _inverse_log_det_jacobian(self, y, **kwargs): - return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access - - def _forward_log_det_jacobian(self, x, **kwargs): - return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 132dc570f94719b6c71fb269866c943774481b7e..06c7c61ec3dc3980e0d12a984739dca5a925ac9f 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -18,16 +18,459 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = [ +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops import variable_scope as variable_scope_lib +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ "MaskedAutoregressiveFlow", - "masked_dense", "masked_autoregressive_default_template", + "masked_dense", ] -remove_undocumented(__name__, _allowed_symbols) + +class MaskedAutoregressiveFlow(bijector_lib.Bijector): + """Affine MaskedAutoregressiveFlow bijector for vector-valued events. + + The affine autoregressive flow [1] provides a relatively simple framework for + user-specified (deep) architectures to learn a distribution over vector-valued + events. Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + In the `tf.distributions` framework, a "normalizing flow" is implemented as a + `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + is implemented using a `tf.while_loop` and a deep neural network (DNN) with + masked weights such that the autoregressive property is automatically met in + the `inverse`. + + A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the + (expensive) forward-mode calculation to draw samples and the (cheap) + reverse-mode calculation to compute log-probabilities. Conversely, a + `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses + the (expensive) forward-mode calculation to compute log-probabilities and the + (cheap) reverse-mode calculation to compute samples. See "Example Use" + [below] for more details. + + Given a `shift_and_log_scale_fn`, the forward and inverse transformations are + (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` + must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka + "alpha" [2]) such that each are broadcastable with the arguments to `forward` + and `inverse`, i.e., such that the calculations in `forward`, `inverse` + [below] are possible. + + For convenience, `masked_autoregressive_default_template` is offered as a + possible `shift_and_log_scale_fn` function. It implements the MADE + architecture [2]. MADE is a feed-forward network that computes a `shift` and + `log(scale)` using `masked_dense` layers in a deep neural network. Weights are + masked to ensure the autoregressive property. It is possible that this + architecture is suboptimal for your task. To build alternative networks, + either change the arguments to `masked_autoregressive_default_template`, use + the `masked_dense` function to roll-out your own, or use some other + architecture, e.g., using `tf.layers`. + + Warning: no attempt is made to validate that the `shift_and_log_scale_fn` + enforces the "autoregressive property". + + Assuming `shift_and_log_scale_fn` has valid shape and autoregressive + semantics, the forward transformation is, + + ```python + def forward(x): + y = zeros_like(x) + event_size = x.shape[-1] + for _ in range(event_size): + shift, log_scale = shift_and_log_scale_fn(y) + y = x * math_ops.exp(log_scale) + shift + return y + ``` + + and the inverse transformation is, + + ```python + def inverse(y): + shift, log_scale = shift_and_log_scale_fn(y) + return (y - shift) / math_ops.exp(log_scale) + ``` + + Notice that the `inverse` does not need a for-loop. This is because in the + forward pass each calculation of `shift` and `log_scale` is based on the `y` + calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is + equivalent to the scaling used in `forward` after `event_size` passes, i.e., + the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this + also proves the transform is bijective.) + + #### Example Use + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + dims = 5 + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.) E.g., + maf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512])), + event_shape=[dims]) + + x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. + maf.log_prob(x) # Almost free; uses Bijector caching. + maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. + + # [1] also describes an "Inverse Autoregressive Flow", e.g., + iaf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512]))), + event_shape=[dims]) + + x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. + iaf.log_prob(x) # Almost free; uses Bijector caching. + iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. + + # In many (if not most) cases the default `shift_and_log_scale_fn` will be a + # poor choice. Here's an example of using a "shift only" version and with a + # different number/depth of hidden layers. + shift_only = True + maf_no_scale_hidden2 = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + tfb.masked_autoregressive_default_template( + hidden_layers=[32], + shift_only=shift_only), + is_constant_jacobian=shift_only), + event_shape=[dims]) + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + """ + + def __init__(self, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + name=None): + """Creates the MaskedAutoregressiveFlow bijector. + + Args: + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). Suggested default + `masked_autoregressive_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`. When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + name = name or "masked_autoregressive_flow" + self._shift_and_log_scale_fn = shift_and_log_scale_fn + super(MaskedAutoregressiveFlow, self).__init__( + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _forward(self, x): + event_size = array_ops.shape(x)[-1] + y0 = array_ops.zeros_like(x, name="y0") + # call the template once to ensure creation + _ = self._shift_and_log_scale_fn(y0) + def _loop_body(index, y0): + """While-loop body for autoregression calculation.""" + # Set caching device to avoid re-getting the tf.Variable for every while + # loop iteration. + with variable_scope_lib.variable_scope( + variable_scope_lib.get_variable_scope()) as vs: + if vs.caching_device is None: + vs.set_caching_device(lambda op: op.device) + shift, log_scale = self._shift_and_log_scale_fn(y0) + y = x + if log_scale is not None: + y *= math_ops.exp(log_scale) + if shift is not None: + y += shift + return index + 1, y + _, y = control_flow_ops.while_loop( + cond=lambda index, _: index < event_size, + body=_loop_body, + loop_vars=[0, y0]) + return y + + def _inverse(self, y): + shift, log_scale = self._shift_and_log_scale_fn(y) + x = y + if shift is not None: + x -= shift + if log_scale is not None: + x *= math_ops.exp(-log_scale) + return x + + def _inverse_log_det_jacobian(self, y): + _, log_scale = self._shift_and_log_scale_fn(y) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + +MASK_INCLUSIVE = "inclusive" +MASK_EXCLUSIVE = "exclusive" + + +def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): + """Generate the slices for building an autoregressive mask.""" + # TODO(b/67594795): Better support of dynamic shape. + slices = [] + col = 0 + d_in = n_in // num_blocks + d_out = n_out // num_blocks + row = d_out if mask_type == MASK_EXCLUSIVE else 0 + for _ in range(num_blocks): + row_slice = slice(row, None) + col_slice = slice(col, col + d_in) + slices.append([row_slice, col_slice]) + col += d_in + row += d_out + return slices + + +def _gen_mask(num_blocks, + n_in, + n_out, + mask_type=MASK_EXCLUSIVE, + dtype=dtypes.float32): + """Generate the mask for building an autoregressive dense layer.""" + # TODO(b/67594795): Better support of dynamic shape. + mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) + slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) + for [row_slice, col_slice] in slices: + mask[row_slice, col_slice] = 1 + return mask + + +def masked_dense(inputs, + units, + num_blocks=None, + exclusive=False, + kernel_initializer=None, + reuse=None, + name=None, + *args, + **kwargs): + """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. + + See [1] for detailed explanation. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + inputs: Tensor input. + units: Python `int` scalar representing the dimensionality of the output + space. + num_blocks: Python `int` scalar representing the number of blocks for the + MADE masks. + exclusive: Python `bool` scalar representing whether to zero the diagonal of + the mask, used for the first layer of a MADE. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the + `tf.glorot_random_initializer`. + reuse: Python `bool` scalar representing whether to reuse the weights of a + previous layer by the same name. + name: Python `str` used to describe ops managed by this function. + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + Output tensor. + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + # TODO(b/67594795): Better support of dynamic shape. + input_depth = inputs.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + + mask = _gen_mask(num_blocks, input_depth, units, + MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T + + if kernel_initializer is None: + kernel_initializer = init_ops.glorot_normal_initializer() + + def masked_initializer(shape, dtype=None, partition_info=None): + return mask * kernel_initializer(shape, dtype, partition_info) + + with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): + layer = layers.Dense( + units, + kernel_initializer=masked_initializer, + kernel_constraint=lambda x: mask * x, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse, + *args, + **kwargs) + return layer.apply(inputs) + + +def masked_autoregressive_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + log_scale_min_clip=-5., + log_scale_max_clip=3., + log_scale_clip_gradient=False, + name=None, + *args, + **kwargs): + """Build the MADE Model [1]. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the input and returns the `loc` ("mu" [1]) and + `log_scale` ("alpha" [1]) from the MADE network. + + Warning: This function uses `masked_dense` to create randomly initialized + `tf.Variables`. It is presumed that these will be fit, just as you would any + other neural architecture which uses `tf.layers.dense`. + + #### About Hidden Layers: + + Each element of `hidden_layers` should be greater than the `input_depth` + (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the + neural network). This is necessary to ensure the autoregressivity property. + + #### About Clipping: + + This function also optionally clips the `log_scale` (but possibly not its + gradient). This is useful because if `log_scale` is too small/large it might + underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` + bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` + `bool` indicates whether the gradient should also be clipped. The default does + not clip the gradient; this is useful because it still provides gradient + information (for fitting) yet solves the numerical stability problem. I.e., + `log_scale_clip_gradient = False` means + `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual + `grad[clip(x)] exp(clip(x))`. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + hidden_layers: Python `list`-like of non-negative integer, scalars + indicating the number of units in each hidden layer. Default: `[512, 512]. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed. Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The minimum value to clip by. Default: -5. + log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The maximum value to clip by. Default: 3. + log_scale_clip_gradient: Python `bool` indicating that the gradient of + `tf.clip_by_value` should be preserved. Default: `False`. + name: A name for ops managed by this function. Default: + "masked_autoregressive_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "masked_autoregressive_default_template", + values=[log_scale_min_clip, log_scale_max_clip]): + def _fn(x): + """MADE parameterized via `masked_autoregressive_default_template`.""" + # TODO(b/67594795): Better support of dynamic shape. + input_depth = x.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() + else array_ops.shape(x)) + for i, units in enumerate(hidden_layers): + x = masked_dense( + inputs=x, + units=units, + num_blocks=input_depth, + exclusive=True if i == 0 else False, + activation=activation, + *args, + **kwargs) + x = masked_dense( + inputs=x, + units=(1 if shift_only else 2) * input_depth, + num_blocks=input_depth, + activation=None, + *args, + **kwargs) + if shift_only: + x = array_ops.reshape(x, shape=input_shape) + return x, None + x = array_ops.reshape( + x, shape=array_ops.concat([input_shape, [2]], axis=0)) + shift, log_scale = array_ops.unstack(x, num=2, axis=-1) + which_clip = (math_ops.clip_by_value if log_scale_clip_gradient + else _clip_by_value_preserve_grad) + log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) + return shift, log_scale + return template_ops.make_template( + "masked_autoregressive_default_template", _fn) + + +def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): + """Clips input while leaving gradient unaltered.""" + with ops.name_scope(name, "clip_by_value_preserve_grad", + [x, clip_value_min, clip_value_max]): + clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) + return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py deleted file mode 100644 index ae142883931274b594dbbafbe86bd71e75c621bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MaskedAutoregressiveFlow bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.layers import core as layers -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import template as template_ops -from tensorflow.python.ops import variable_scope as variable_scope_lib -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "MaskedAutoregressiveFlow", - "masked_autoregressive_default_template", - "masked_dense", -] - - -class MaskedAutoregressiveFlow(bijector_lib.Bijector): - """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, - - "Autoregressive models decompose the joint density as a product of - conditionals, and model each conditional in turn. Normalizing flows - transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] - - In other words, the "autoregressive property" is equivalent to the - decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided - `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves - this property by zeroing out weights in its `masked_dense` layers. - - In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" - is implemented using a `tf.while_loop` and a deep neural network (DNN) with - masked weights such that the autoregressive property is automatically met in - the `inverse`. - - A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the - (expensive) forward-mode calculation to draw samples and the (cheap) - reverse-mode calculation to compute log-probabilities. Conversely, a - `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses - the (expensive) forward-mode calculation to compute log-probabilities and the - (cheap) reverse-mode calculation to compute samples. See "Example Use" - [below] for more details. - - Given a `shift_and_log_scale_fn`, the forward and inverse transformations are - (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. - - For convenience, `masked_autoregressive_default_template` is offered as a - possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. - - Warning: no attempt is made to validate that the `shift_and_log_scale_fn` - enforces the "autoregressive property". - - Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, - - ```python - def forward(x): - y = zeros_like(x) - event_size = x.shape[-1] - for _ in range(event_size): - shift, log_scale = shift_and_log_scale_fn(y) - y = x * math_ops.exp(log_scale) + shift - return y - ``` - - and the inverse transformation is, - - ```python - def inverse(y): - shift, log_scale = shift_and_log_scale_fn(y) - return (y - shift) / math_ops.exp(log_scale) - ``` - - Notice that the `inverse` does not need a for-loop. This is because in the - forward pass each calculation of `shift` and `log_scale` is based on the `y` - calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is - equivalent to the scaling used in `forward` after `event_size` passes, i.e., - the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this - also proves the transform is bijective.) - - #### Example Use - - ```python - ds = tf.contrib.distributions - bs = tf.contrib.distributions.bijectors - - dims = 5 - - # A common choice for a normalizing flow is to use a Gaussian for the base - # distribution. (However, any continuous distribution would work.) E.g., - maf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512])), - event_shape=[dims]) - - x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. - maf.log_prob(x) # Almost free; uses Bijector caching. - maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. - - # [1] also describes an "Inverse Autoregressive Flow", e.g., - iaf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.Invert(bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512]))), - event_shape=[dims]) - - x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. - iaf.log_prob(x) # Almost free; uses Bijector caching. - iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. - - # In many (if not most) cases the default `shift_and_log_scale_fn` will be a - # poor choice. Here's an example of using a "shift only" version and with a - # different number/depth of hidden layers. - shift_only = True - maf_no_scale_hidden2 = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - bs.masked_autoregressive_default_template( - hidden_layers=[32], - shift_only=shift_only), - is_constant_jacobian=shift_only), - event_shape=[dims]) - ``` - - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 - - [2]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - """ - - def __init__(self, - shift_and_log_scale_fn, - is_constant_jacobian=False, - validate_args=False, - name=None): - """Creates the MaskedAutoregressiveFlow bijector. - - Args: - shift_and_log_scale_fn: Python `callable` which computes `shift` and - `log_scale` from both the forward domain (`x`) and the inverse domain - (`y`). Calculation must respect the "autoregressive property" (see class - docstring). Suggested default - `masked_autoregressive_default_template(hidden_layers=...)`. - Typically the function contains `tf.Variables` and is wrapped using - `tf.make_template`. Returning `None` for either (both) `shift`, - `log_scale` is equivalent to (but more efficient than) returning zero. - is_constant_jacobian: Python `bool`. Default: `False`. When `True` the - implementation assumes `log_scale` does not depend on the forward domain - (`x`) or inverse domain (`y`) values. (No validation is made; - `is_constant_jacobian=False` is always safe but possibly computationally - inefficient.) - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - name = name or "masked_autoregressive_flow" - self._shift_and_log_scale_fn = shift_and_log_scale_fn - super(MaskedAutoregressiveFlow, self).__init__( - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - - def _forward(self, x): - event_size = array_ops.shape(x)[-1] - def _loop_body(index, y0): - """While-loop body for autoregression calculation.""" - # Set caching device to avoid re-getting the tf.Variable for every while - # loop iteration. - with variable_scope_lib.variable_scope( - variable_scope_lib.get_variable_scope()) as vs: - if vs.caching_device is None: - vs.set_caching_device(lambda op: op.device) - shift, log_scale = self._shift_and_log_scale_fn(y0) - y = x - if log_scale is not None: - y *= math_ops.exp(log_scale) - if shift is not None: - y += shift - return index + 1, y - _, y = control_flow_ops.while_loop( - cond=lambda index, _: index < event_size, - body=_loop_body, - loop_vars=[0, array_ops.zeros_like(x, name="y0")]) - return y - - def _inverse(self, y): - shift, log_scale = self._shift_and_log_scale_fn(y) - x = y - if shift is not None: - x -= shift - if log_scale is not None: - x *= math_ops.exp(-log_scale) - return x - - def _inverse_log_det_jacobian(self, y): - _, log_scale = self._shift_and_log_scale_fn(y) - if log_scale is None: - return constant_op.constant(0., dtype=y.dtype, name="ildj") - return -math_ops.reduce_sum(log_scale, axis=-1) - - -MASK_INCLUSIVE = "inclusive" -MASK_EXCLUSIVE = "exclusive" - - -def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): - """Generate the slices for building an autoregressive mask.""" - # TODO(b/67594795): Better support of dynamic shape. - slices = [] - col = 0 - d_in = n_in // num_blocks - d_out = n_out // num_blocks - row = d_out if mask_type == MASK_EXCLUSIVE else 0 - for _ in range(num_blocks): - row_slice = slice(row, None) - col_slice = slice(col, col + d_in) - slices.append([row_slice, col_slice]) - col += d_in - row += d_out - return slices - - -def _gen_mask(num_blocks, - n_in, - n_out, - mask_type=MASK_EXCLUSIVE, - dtype=dtypes.float32): - """Generate the mask for building an autoregressive dense layer.""" - # TODO(b/67594795): Better support of dynamic shape. - mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) - slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) - for [row_slice, col_slice] in slices: - mask[row_slice, col_slice] = 1 - return mask - - -def masked_dense(inputs, - units, - num_blocks=None, - exclusive=False, - kernel_initializer=None, - reuse=None, - name=None, - *args, - **kwargs): - """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - inputs: Tensor input. - units: Python `int` scalar representing the dimensionality of the output - space. - num_blocks: Python `int` scalar representing the number of blocks for the - MADE masks. - exclusive: Python `bool` scalar representing whether to zero the diagonal of - the mask, used for the first layer of a MADE. - kernel_initializer: Initializer function for the weight matrix. - If `None` (default), weights are initialized using the - `tf.glorot_random_initializer`. - reuse: Python `bool` scalar representing whether to reuse the weights of a - previous layer by the same name. - name: Python `str` used to describe ops managed by this function. - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - Output tensor. - - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - # TODO(b/67594795): Better support of dynamic shape. - input_depth = inputs.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - - mask = _gen_mask(num_blocks, input_depth, units, - MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T - - if kernel_initializer is None: - kernel_initializer = init_ops.glorot_normal_initializer() - - def masked_initializer(shape, dtype=None, partition_info=None): - return mask * kernel_initializer(shape, dtype, partition_info) - - with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): - layer = layers.Dense( - units, - kernel_initializer=masked_initializer, - kernel_constraint=lambda x: mask * x, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse, - *args, - **kwargs) - return layer.apply(inputs) - - -def masked_autoregressive_default_template( - hidden_layers, - shift_only=False, - activation=nn_ops.relu, - log_scale_min_clip=-5., - log_scale_max_clip=3., - log_scale_clip_gradient=False, - name=None, - *args, - **kwargs): - """Build the MADE Model [1]. - - This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. - - Warning: This function uses `masked_dense` to create randomly initialized - `tf.Variables`. It is presumed that these will be fit, just as you would any - other neural architecture which uses `tf.layers.dense`. - - #### About Hidden Layers: - - Each element of `hidden_layers` should be greater than the `input_depth` - (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the - neural network). This is necessary to ensure the autoregressivity property. - - #### About Clipping: - - This function also optionally clips the `log_scale` (but possibly not its - gradient). This is useful because if `log_scale` is too small/large it might - underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` - bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` - `bool` indicates whether the gradient should also be clipped. The default does - not clip the gradient; this is useful because it still provides gradient - information (for fitting) yet solves the numerical stability problem. I.e., - `log_scale_clip_gradient = False` means - `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual - `grad[clip(x)] exp(clip(x))`. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - hidden_layers: Python `list`-like of non-negative integer, scalars - indicating the number of units in each hidden layer. Default: `[512, 512]. - shift_only: Python `bool` indicating if only the `shift` term shall be - computed. Default: `False`. - activation: Activation function (callable). Explicitly setting to `None` - implies a linear activation. - log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The minimum value to clip by. Default: -5. - log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The maximum value to clip by. Default: 3. - log_scale_clip_gradient: Python `bool` indicating that the gradient of - `tf.clip_by_value` should be preserved. Default: `False`. - name: A name for ops managed by this function. Default: - "masked_autoregressive_default_template". - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). - - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - - with ops.name_scope(name, "masked_autoregressive_default_template", - values=[log_scale_min_clip, log_scale_max_clip]): - def _fn(x): - """MADE parameterized via `masked_autoregressive_default_template`.""" - # TODO(b/67594795): Better support of dynamic shape. - input_depth = x.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() - else array_ops.shape(x)) - for i, units in enumerate(hidden_layers): - x = masked_dense( - inputs=x, - units=units, - num_blocks=input_depth, - exclusive=True if i == 0 else False, - activation=activation, - *args, - **kwargs) - x = masked_dense( - inputs=x, - units=(1 if shift_only else 2) * input_depth, - num_blocks=input_depth, - activation=None, - *args, - **kwargs) - if shift_only: - x = array_ops.reshape(x, shape=input_shape) - return x, None - x = array_ops.reshape( - x, shape=array_ops.concat([input_shape, [2]], axis=0)) - shift, log_scale = array_ops.unstack(x, num=2, axis=-1) - which_clip = (math_ops.clip_by_value if log_scale_clip_gradient - else _clip_by_value_preserve_grad) - log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) - return shift, log_scale - return template_ops.make_template( - "masked_autoregressive_default_template", _fn) - - -def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): - """Clips input while leaving gradient unaltered.""" - with ops.name_scope(name, "clip_by_value_preserve_grad", - [x, clip_value_min, clip_value_max]): - clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) - return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index a187ce22d686ee1203802ae2bfe64b0e1a3ea850..8654cc39d0c41ec4f1b85cd5fc4366ceaf4b224d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -12,18 +12,127 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Permute bijector.""" +"""Permutation bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.permute_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Permute"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Permute", +] + + +class Permute(bijector_lib.Bijector): + """Permutes the rightmost dimension of a `Tensor`. + + ```python + tfd = tf.contrib.distributions + + reverse = tfd.bijectors.Permute(permutation=[2, 1, 0]) + + reverse.forward([-1., 0., 1.]) + # ==> [1., 0., -1] + + reverse.inverse([1., 0., -1]) + # ==> [-1., 0., 1.] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + Warning: `tf.estimator` may repeatedly build the graph thus + `Permute(np.random.permutation(event_size)).astype("int32"))` is not a + reliable parameterization (nor would it be even if using `tf.constant`). A + safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + i.e., + + ```python + def init_once(x, name): + return tf.get_variable(name, initializer=x, trainable=False) + + Permute(permutation=init_once( + np.random.permutation(event_size).astype("int32"), + name="permutation")) + ``` + + """ + + def __init__(self, permutation, validate_args=False, name=None): + """Creates the `Permute` bijector. + + Args: + permutation: An `int`-like vector-shaped `Tensor` representing the + permutation to apply to the rightmost dimension of the transformed + `Tensor`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if `not permutation.dtype.is_integer`. + ValueError: if `permutation` does not contain exactly one of each of + `{0, 1, ..., d}`. + """ + with ops.name_scope(name, "permute", values=[permutation]): + permutation = ops.convert_to_tensor( + permutation, + name="permutation") + if not permutation.dtype.is_integer: + raise TypeError("permutation.dtype ({}) should be `int`-like.".format( + permutation.dtype.name)) + p = tensor_util.constant_value(permutation) + if p is not None: + if set(p) != set(np.arange(p.size)): + raise ValueError("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.") + elif validate_args: + p, _ = nn_ops.top_k(-permutation, + k=array_ops.shape(permutation)[-1], + sorted=True) + permutation = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + -p, math_ops.range(array_ops.size(p)), + message=("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.")), + ], permutation) + self._permutation = permutation + super(Permute, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "permute") + + @property + def permutation(self): + return self._permutation + + def _forward(self, x): + return array_ops.gather(x, self.permutation, axis=-1) + + def _inverse(self, y): + return array_ops.gather( + y, + array_ops.invert_permutation(self.permutation), + axis=-1) + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py deleted file mode 100644 index b1d8f2f41b28a88208a19824377f93882b767f03..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Permutation bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Permute", -] - - -class Permute(bijector_lib.Bijector): - """Permutes the rightmost dimension of a `Tensor`. - - ```python - bs = tf.contrib.distributions.bijectors - - reverse = bs.Permute(permutation=[2, 1, 0]) - - reverse.forward([-1., 0., 1.]) - # ==> [1., 0., -1] - - reverse.inverse([1., 0., -1]) - # ==> [-1., 0., 1.] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. - - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - Warning: `tf.estimator` may repeatedly build the graph thus - `Permute(np.random.permutation(event_size)).astype("int32"))` is not a - reliable parameterization (nor would it be even if using `tf.constant`). A - safe alternative is to use `tf.get_variable` to achieve "init once" behavior, - i.e., - - ```python - def init_once(x, name): - return tf.get_variable(name, initializer=x, trainable=False) - - Permute(permutation=init_once( - np.random.permutation(event_size).astype("int32"), - name="permutation")) - ``` - - """ - - def __init__(self, permutation, validate_args=False, name=None): - """Creates the `Permute` bijector. - - Args: - permutation: An `int`-like vector-shaped `Tensor` representing the - permutation to apply to the rightmost dimension of the transformed - `Tensor`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if `not permutation.dtype.is_integer`. - ValueError: if `permutation` does not contain exactly one of each of - `{0, 1, ..., d}`. - """ - with ops.name_scope(name, "permute", values=[permutation]): - permutation = ops.convert_to_tensor( - permutation, - name="permutation") - if not permutation.dtype.is_integer: - raise TypeError("permutation.dtype ({}) should be `int`-like.".format( - permutation.dtype.name)) - p = tensor_util.constant_value(permutation) - if p is not None: - if set(p) != set(np.arange(p.size)): - raise ValueError("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.") - elif validate_args: - p, _ = nn_ops.top_k(-permutation, - k=array_ops.shape(permutation)[-1], - sorted=True) - permutation = control_flow_ops.with_dependencies([ - check_ops.assert_equal( - -p, math_ops.range(array_ops.size(p)), - message=("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.")), - ], permutation) - self._permutation = permutation - super(Permute, self).__init__( - is_constant_jacobian=True, - validate_args=validate_args, - name=name or "permute") - - @property - def permutation(self): - return self._permutation - - def _forward(self, x): - return array_ops.gather(x, self.permutation, axis=-1) - - def _inverse(self, y): - return array_ops.gather( - y, - array_ops.invert_permutation(self.permutation), - axis=-1) - - def _inverse_log_det_jacobian(self, y): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index a83199549cd16101ab7b39b43d19a17bc66f03df..c37db61720d10949f294ff7b2e9778ba6efa57f0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -18,12 +18,110 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.power_transform_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["PowerTransform"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "PowerTransform", +] + + +class PowerTransform(bijector.Bijector): + """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. + + The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps + inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` + of this bijector. + + This bijector is equivalent to the `Exp` bijector when `c=0`. + """ + + def __init__(self, + power=0., + event_ndims=0, + validate_args=False, + name="power_transform"): + """Instantiates the `PowerTransform` bijector. + + Args: + power: Python `float` scalar indicating the transform power, i.e., + `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `power < 0` or is not known statically. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[power]): + power = tensor_util.constant_value( + ops.convert_to_tensor(power, name="power")) + if power is None or power < 0: + raise ValueError("`power` must be a non-negative TF constant.") + self._power = power + super(PowerTransform, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def power(self): + """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" + return self._power + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + if self.power == 0.: + return math_ops.exp(x) + # If large x accuracy is an issue, consider using: + # (1. + x * self.power)**(1. / self.power) when x >> 1. + return math_ops.exp(math_ops.log1p(x * self.power) / self.power) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + if self.power == 0.: + return math_ops.log(y) + # If large y accuracy is an issue, consider using: + # (y**self.power - 1.) / self.power when y >> 1. + return math_ops.expm1(math_ops.log(y) * self.power) / self.power + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return (self.power - 1.) * math_ops.reduce_sum( + math_ops.log(y), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + if self.power == 0.: + return math_ops.reduce_sum(x, axis=event_dims) + return (1. / self.power - 1.) * math_ops.reduce_sum( + math_ops.log1p(x * self.power), + axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args or self.power == 0.: + return x + is_valid = check_ops.assert_non_negative( + 1. + self.power * x, + message="Forward transformation input must be at least {}.".format( + -1. / self.power)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_valid = check_ops.assert_positive( + y, message="Inverse transformation input must be greater than 0.") + return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py deleted file mode 100644 index c37db61720d10949f294ff7b2e9778ba6efa57f0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""PowerTransform bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "PowerTransform", -] - - -class PowerTransform(bijector.Bijector): - """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. - - The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps - inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` - of this bijector. - - This bijector is equivalent to the `Exp` bijector when `c=0`. - """ - - def __init__(self, - power=0., - event_ndims=0, - validate_args=False, - name="power_transform"): - """Instantiates the `PowerTransform` bijector. - - Args: - power: Python `float` scalar indicating the transform power, i.e., - `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `power < 0` or is not known statically. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[power]): - power = tensor_util.constant_value( - ops.convert_to_tensor(power, name="power")) - if power is None or power < 0: - raise ValueError("`power` must be a non-negative TF constant.") - self._power = power - super(PowerTransform, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def power(self): - """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" - return self._power - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - if self.power == 0.: - return math_ops.exp(x) - # If large x accuracy is an issue, consider using: - # (1. + x * self.power)**(1. / self.power) when x >> 1. - return math_ops.exp(math_ops.log1p(x * self.power) / self.power) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - if self.power == 0.: - return math_ops.log(y) - # If large y accuracy is an issue, consider using: - # (y**self.power - 1.) / self.power when y >> 1. - return math_ops.expm1(math_ops.log(y) * self.power) / self.power - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return (self.power - 1.) * math_ops.reduce_sum( - math_ops.log(y), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - if self.power == 0.: - return math_ops.reduce_sum(x, axis=event_dims) - return (1. / self.power - 1.) * math_ops.reduce_sum( - math_ops.log1p(x * self.power), - axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args or self.power == 0.: - return x - is_valid = check_ops.assert_non_negative( - 1. + self.power * x, - message="Forward transformation input must be at least {}.".format( - -1. / self.power)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_valid = check_ops.assert_positive( - y, message="Inverse transformation input must be greater than 0.") - return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 8997f7ab6929745275edb38712a5bbb0a9b25ddb..55eca063126797d577653f0d6bcdfddf8192bdb5 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -12,18 +12,303 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Reshape bijector.""" +"""Reshape bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Reshape"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Reshape", +] + + +def _static_ndims_from_shape(shape): + return shape.shape.with_rank_at_least(1)[0].value + + +def _ndims_from_shape(shape): + return array_ops.shape(shape)[0] + + +class Reshape(bijector_lib.Bijector): + """Reshapes the `event_shape` of a `Tensor`. + + The semantics generally follow that of `tf.reshape()`, with + a few differences: + + * The user must provide both the input and output shape, so that + the transformation can be inverted. If an input shape is not + specified, the default assumes a vector-shaped input, i.e., + event_shape_in = (-1,). + * The `Reshape` bijector automatically broadcasts over the leftmost + dimensions of its input (`sample_shape` and `batch_shape`); only + the rightmost `event_ndims_in` dimensions are reshaped. The + number of dimensions to reshape is inferred from the provided + `event_shape_in` (`event_ndims_in = len(event_shape_in)`). + + Example usage: + ```python + + tfd = tf.contrib.distributions + + r = tfd.bijectors.Reshape(event_shape_out=[1, -1]) + + r.forward([3., 4.]) # shape [2] + # ==> [[3., 4.]] # shape [1, 2] + + r.forward([[1., 2.], [3., 4.]]) # shape [2, 2] + # ==> [[[1., 2.]], + # [[3., 4.]]] # shape [2, 1, 2] + + r.inverse([[3., 4.]]) # shape [1,2] + # ==> [3., 4.] # shape [2] + + r.forward_log_det_jacobian(any_value) + # ==> 0. + + r.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + """ + + def __init__(self, event_shape_out, event_shape_in=(-1,), + validate_args=False, name=None): + """Creates a `Reshape` bijector. + + Args: + event_shape_out: An `int`-like vector-shaped `Tensor` + representing the event shape of the transformed output. + event_shape_in: An optional `int`-like vector-shape `Tensor` + representing the event shape of the input. This is required in + order to define inverse operations; the default of (-1,) + assumes a vector-shaped input. + validate_args: Python `bool` indicating whether arguments should + be checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if either `event_shape_in` or `event_shape_out` has + non-integer `dtype`. + ValueError: if either of `event_shape_in` or `event_shape_out` + has non-vector shape (`rank > 1`), or if their sizes do not + match. + """ + with ops.name_scope(name, "reshape", + values=[event_shape_out, event_shape_in]): + + event_shape_out = ops.convert_to_tensor(event_shape_out, + name="event_shape_out", + preferred_dtype=dtypes.int32) + event_shape_in = ops.convert_to_tensor(event_shape_in, + name="event_shape_in", + preferred_dtype=dtypes.int32) + + assertions = [] + assertions.extend(self._maybe_check_valid_shape( + event_shape_out, validate_args)) + assertions.extend(self._maybe_check_valid_shape( + event_shape_in, validate_args)) + + self._assertions = assertions + self._event_shape_in = event_shape_in + self._event_shape_out = event_shape_out + + super(Reshape, self).__init__(is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") + + def _maybe_check_valid_shape(self, shape, validate_args): + """Check that a shape Tensor is int-type and otherwise sane.""" + if not shape.dtype.is_integer: + raise TypeError("{} dtype ({}) should be `int`-like.".format( + shape.op.name, shape.dtype.name)) + + assertions = [] + + ndims = array_ops.rank(shape) + ndims_ = tensor_util.constant_value(ndims) + if ndims_ is not None and ndims_ > 1: + raise ValueError("`{}` rank ({}) should be <= 1.".format( + shape.op.name, ndims_)) + elif validate_args: + assertions.append(check_ops.assert_less_equal( + ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name))) + + shape_ = tensor_util.constant_value_as_shape(shape) + if shape_.is_fully_defined(): + es = np.int32(shape_.as_list()) + if sum(es == -1) > 1: + raise ValueError( + "`{}` must have at most one `-1` (given {})" + .format(shape.op.name, es)) + if np.any(es < -1): + raise ValueError( + "`{}` elements must be either positive integers or `-1`" + "(given {})." + .format(shape.op.name, es)) + elif validate_args: + assertions.extend([ + check_ops.assert_less_equal( + math_ops.reduce_sum( + math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)), + 1, + message="`{}` elements must have at most one `-1`." + .format(shape.op.name)), + check_ops.assert_greater_equal( + shape, -1, + message="`{}` elements must be either positive integers or `-1`." + .format(shape.op.name)), + ]) + return assertions + + def _reshape_helper(self, x, event_shape_in, event_shape_out): + """Reshape only the event_shape of an input `Tensor`.""" + + event_ndims_in_ = _static_ndims_from_shape(event_shape_in) + event_ndims_in = _ndims_from_shape(event_shape_in) + x_ndims_, x_ndims = x.shape.ndims, array_ops.rank(x) + + assertions = [] + + # Ensure x.event_shape is compatible with event_shape_in. + if (event_ndims_in_ is not None + and x_ndims_ is not None + and x.shape.with_rank_at_least(event_ndims_in_)[ + x_ndims_-event_ndims_in_:].is_fully_defined()): + x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking + np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 + else: + x_event_shape_, x_event_shape = ( + None, array_ops.shape(x)[x_ndims-event_ndims_in:]) + + event_shape_in_ = tensor_util.constant_value(event_shape_in) + + if x_event_shape_ is not None and event_shape_in_ is not None: + # Compare the shape dimensions that are fully specified in the + # input (i.e., for which event_shape_in is not -1). If x_event_shape + # matches along all of these dimensions, it is compatible with + # the desired input shape and any further mismatches (i.e., + # imcompatibility with the desired *output* shape) will be + # caught inside of array_ops.reshape() below. + x_event_shape_specified_ = x_event_shape_[event_shape_in_ >= 0] + event_shape_in_specified_ = event_shape_in_[event_shape_in_ >= 0] + if not np.equal(x_event_shape_specified_, + event_shape_in_specified_).all(): + raise ValueError( + "Input `event_shape` does not match `event_shape_in` ({} vs {}).". + format(x_event_shape_, event_shape_in_)) + elif self.validate_args: + # Similarly to the static case, we compare the shape dimensions + # that are fully specified in the input. We extract these + # dimensions using boolean_mask(), which requires that the mask + # have known ndims. We can assume that shape Tensors always have + # ndims==1 (this assumption is verified inside of + # _maybe_check_valid_shape), so the reshape operation is just a + # no-op that formally encodes this fact to make boolean_mask() + # happy. + event_shape_mask = array_ops.reshape(event_shape_in >= 0, [-1]) + x_event_shape_specified = array_ops.boolean_mask(x_event_shape, + event_shape_mask) + event_shape_in_specified = array_ops.boolean_mask(event_shape_in, + event_shape_mask) + assertions.append(check_ops.assert_equal( + x_event_shape_specified, event_shape_in_specified, + message="Input `event_shape` does not match `event_shape_in`.")) + + if assertions: + x = control_flow_ops.with_dependencies(assertions, x) + + # get the parts of shape(x) that will not change + sample_and_batch_shape = array_ops.shape(x) + + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x)) + sample_and_batch_shape = sample_and_batch_shape[ + :(ndims - math_ops.abs(event_ndims_in))] + + if (event_ndims_in_ is not None + and x_ndims_ is not None + and event_ndims_in_ == x_ndims_): + # Hack to allow forward/inverse_event_shape to do shape + # inference by calling this helper method with a dummy Tensor of + # shape event_shape_in. In this special case, + # sample_and_batch_shape will be empty so we can preserve static + # shape information by avoiding the concat operation below + # (which would be a no-op). + new_shape = event_shape_out + else: + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) + + return array_ops.reshape(x, new_shape) + + def _forward(self, x): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(x, + self._event_shape_in, + self._event_shape_out) + + def _inverse(self, y): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(y, + self._event_shape_out, + self._event_shape_in) + + def _inverse_log_det_jacobian(self, y): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=x.dtype) + + def _forward_event_shape(self, input_shape): + # NOTE: this method and the other *_event_shape* methods + # compute shape by explicit transformation of a dummy + # variable. This approach is not generally recommended because it + # bloats the graph and could in general trigger side effects. + # + # In this particular case of the Reshape bijector, the + # forward and inverse transforms have no side effects, and we + # believe the reduction in code complexity from delegating the + # heavy lifting to tf.reshape() is worth the added graph ops. + # However, you should think hard before implementing this approach + # in other Bijectors; it is strongly preferred to compute + # shapes explicitly whenever it's feasible to do so. + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return dummy_reshaped.shape + + def _inverse_event_shape(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return dummy_reshaped.shape + + def _forward_event_shape_tensor(self, input_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return array_ops.shape(dummy_reshaped) + + def _inverse_event_shape_tensor(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return array_ops.shape(dummy_reshaped) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py deleted file mode 100644 index 93682639aa3be3b8f59a369dedb6ee773c468130..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Reshape bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Reshape", -] - - -class Reshape(bijector_lib.Bijector): - """Reshapes the `event_shape` of a `Tensor`. - - The semantics generally follow that of `tf.reshape()`, with - a few differences: - * The user must provide both the input and output shape, so that - the transformation can be inverted. - * The `Reshape` bijector automatically broadcasts over the leftmost - dimensions of its input (`sample_shape` and `batch_shape`); only - the rightmost `event_ndims_in` dimensions are reshaped. The - number of dimensions to reshape is inferred from the provided - `event_shape_in` (`event_ndims_in = len(event_shape_in)`). - * The `Reshape` bijector does not currently support - partially-specified shapes, i.e., those with a dimension - implicitly specified by `-1`. - - Example usage: - ```python - - bs = tf.contrib.distributions.bijectors - - reverse = bs.Reshape(event_shape_out=[1,2], - event_shape_in=[2,]) - - reverse.forward([1., 2.]) # shape [2,] - # ==> [[1., 2.]] # shape [1,2] - - reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2] - # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2] - - reverse.inverse([[1., 2.]]) # shape [1,2] - # ==> [1., 2.] # shape [2,] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. - - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - """ - - def __init__(self, event_shape_out, event_shape_in, - validate_args=False, name=None): - """Creates a `Reshape` bijector. - - Args: - event_shape_out: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - transformed output. - event_shape_in: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - input. - validate_args: Python `bool` indicating whether arguments should - be checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if either `event_shape_in` or `event_shape_out` has - non-vector shape (`rank > 1`), or non-integer `dtype`. - ValueError: if either `event_shape_in` or `event_shape_out` - contains non-positive entries, or if their sizes do not match - (`prod(event_shape_in)` != `prod(event_shape_out)`), or if - their dimensionality(s) cannot be statically inferred. - """ - with ops.name_scope(name, "reshape", - values=[event_shape_out, event_shape_in]): - - event_shape_out = ops.convert_to_tensor(event_shape_out, - name="event_shape_out", - preferred_dtype=dtypes.int32) - event_shape_in = ops.convert_to_tensor(event_shape_in, - name="event_shape_in", - preferred_dtype=dtypes.int32) - - # check that input shapes are positive integers - assertions = [] - assertions += self._maybe_check_valid_shape( - event_shape_out, "event_shape_out", - validate_args=validate_args) - assertions += self._maybe_check_valid_shape( - event_shape_in, "event_shape_in", validate_args=validate_args) - - # check that prod(event_shape_in) = prod(event_shape_out) - assertions += self._maybe_check_matching_sizes( - event_shape_in, event_shape_out, validate_args=validate_args) - - self._assertions = assertions - self._event_shape_in = event_shape_in - self._event_shape_out = event_shape_out - self._event_shape_in_static = tensor_util.constant_value_as_shape( - event_shape_in) - self._event_shape_out_static = tensor_util.constant_value_as_shape( - event_shape_out) - - super(Reshape, self).__init__(is_constant_jacobian=True, - validate_args=validate_args, - name=name or "reshape") - - def _maybe_check_valid_shape(self, shape_tensor, label, - validate_args=False): - """Check that a shape Tensor is int-type and positive.""" - - assertions = [] - - if not shape_tensor.dtype.is_integer: - raise TypeError("{} dtype ({}) should be `int`-like.".format( - label, shape_tensor.dtype.name)) - - shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor)) - if shape_rank is not None and shape_rank > 1: - raise ValueError("{} rank should be <= 1.".format(label)) - - s = tensor_util.constant_value(shape_tensor) - if s is not None: - if (s <= 0).any(): - raise ValueError("{} entries must be positive, but found {}".format( - label, s)) - elif validate_args: - assertions.append(check_ops.assert_positive( - shape_tensor, message="{} entries must be positive".format(label))) - - return assertions - - def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out, - validate_args=False): - """Check that prod(event_shape_in)==prod(event_shape_out).""" - - def _get_size_from_shape(shape): - """Computes size from a shape `Tensor`, statically if possible.""" - s = tensor_util.constant_value(shape) - if s is not None: - return [np.int32(np.prod(s))]*2 - return None, math_ops.reduce_prod(shape, name="size") - - # Ensure `event_shape_in` is compatible with `event_shape_out`. - event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_in) - event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_out) - - assertions = [] - if event_size_in_ is not None and event_size_out_ is not None: - if event_size_in_ != event_size_out_: - raise ValueError( - "Input `event_size` ({}) does not match output `event_size` ({}).". - format(event_size_in, event_size_out_)) - elif validate_args: - assertions.append(check_ops.assert_equal( - event_size_in, event_size_out, - message="Input/output `event_size`s do not match.")) - - return assertions - - def _reshape_helper(self, x, event_shape_in, event_shape_out): - """Reshape only the event_shape of an input `Tensor`.""" - - def _get_rank_from_shape(shape): - """Computes rank from a shape `Tensor`, statically if possible.""" - # Uses fact that rank is "shape of shape". - ndims = shape.shape.with_rank_at_least(1)[0].value - if ndims is not None: - return ndims, ndims - return None, array_ops.shape(shape)[0] - - event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) - - assertions = [] - # Ensure x.event_shape is compatible with event_shape_in. - if x.shape.ndims is not None: - x_ndims_, x_ndims = [x.shape.ndims]*2 - else: - x_ndims_, x_ndims = None, array_ops.rank(x) - - if (event_ndims_in_ is not None - and x_ndims_ is not None - and x.shape.with_rank_at_least(event_ndims_in_)[ - x_ndims_-event_ndims_in_:].is_fully_defined()): - x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking - np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 - else: - x_event_shape_, x_event_shape = ( - None, array_ops.shape(x)[x_ndims-event_ndims_in:]) - - event_shape_in_ = tensor_util.constant_value(event_shape_in) - - if x_event_shape_ is not None and event_shape_in_ is not None: - if not np.equal(x_event_shape_, event_shape_in_).all(): - raise ValueError( - "Input `event_shape` ({}) does not match `event_shape_in` ({}).". - format(x_event_shape_, event_shape_in_)) - elif self.validate_args: - assertions.append(check_ops.assert_equal( - x_event_shape, event_shape_in, - message="Input `event_shape` does not match `event_shape_in`.")) - - if assertions: - x = control_flow_ops.with_dependencies(assertions, x) - - # get the parts of shape(x) that will not change - sample_and_batch_shape = array_ops.shape(x) - - ndims = (x.shape.ndims if x.shape.ndims is not None - else array_ops.rank(x)) - sample_and_batch_shape = sample_and_batch_shape[ - :(ndims - math_ops.abs(event_ndims_in))] - - new_shape = array_ops.concat( - [sample_and_batch_shape, event_shape_out], axis=0) - - return array_ops.reshape(x, new_shape) - - def _forward(self, x): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(x, - self._event_shape_in, - self._event_shape_out) - - def _inverse(self, y): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(y, - self._event_shape_out, - self._event_shape_in) - - def _inverse_log_det_jacobian(self, y): - with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=x.dtype) - - def _forward_event_shape(self, input_shape): - self._event_shape_in_static.assert_is_compatible_with(input_shape) - return self._event_shape_out_static - - def _inverse_event_shape(self, output_shape): - self._event_shape_out_static.assert_is_compatible_with(output_shape) - return self._event_shape_in_static - - def _forward_event_shape_tensor(self, input_shape): - input_assertions = self._maybe_check_valid_shape( - input_shape, "input event shape", validate_args=self.validate_args) - input_assertions += self._maybe_check_matching_sizes( - input_shape, self._event_shape_out, - validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - input_assertions + self._assertions, self._event_shape_out) - - def _inverse_event_shape_tensor(self, output_shape): - - output_assertions = self._maybe_check_valid_shape( - output_shape, "output event shape", validate_args=self.validate_args) - output_assertions += self._maybe_check_matching_sizes( - output_shape, self._event_shape_in, validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - output_assertions + self._assertions, self._event_shape_in) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index c20e76c0b7367369865faf973377201c8b8b17e6..a640dfe7dfbcce96261589c7fc49107deaefdd54 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -18,12 +18,31 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Sigmoid"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Sigmoid", +] + + +class Sigmoid(bijector.Bijector): + """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + + def __init__(self, validate_args=False, name="sigmoid"): + super(Sigmoid, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) + + def _forward(self, x): + return math_ops.sigmoid(x) + + def _inverse(self, y): + return math_ops.log(y) - math_ops.log1p(-y) + + def _inverse_log_det_jacobian(self, y): + return -math_ops.log(y) - math_ops.log1p(-y) + + def _forward_log_det_jacobian(self, x): + return -nn_ops.softplus(-x) - nn_ops.softplus(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py index 448125230d24066697624bce03fed71a2c2f00b1..223bc9d042c69be05b0e578835a31ed6e83c0c97 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py @@ -18,12 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered -_allowed_symbols = ["SigmoidCentered"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SigmoidCentered", +] + + +class SigmoidCentered(softmax_centered.SoftmaxCentered): + """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(-X)). + + Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. + + See `bijector.SoftmaxCentered` for more details. + """ + + def __init__(self, validate_args=False, name="sigmoid_centered"): + super(SigmoidCentered, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index b3cf03c24612f5c618c71c0a8615f272acdf2d10..3a75e4ae9495793901b0da91a5aa3982aab35852 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -18,12 +18,162 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SinhArcsinh"] +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SinhArcsinh", +] + + +def _sqrtx2p1(x): + """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" + return array_ops.where( + math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., + math_ops.sqrt(x**2. + 1.), + # For large x, calculating x**2 can overflow. This can be alleviated by + # considering: + # sqrt(1 + x**2) + # = exp(0.5 log(1 + x**2)) + # = exp(0.5 log(x**2 * (1 + x**-2))) + # = exp(log(x) + 0.5 * log(1 + x**-2)) + # = |x| * exp(0.5 log(1 + x**-2)) + # = |x| * sqrt(1 + x**-2) + # We omit the last term in this approximation. + # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, + # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, + # and higher order gradients, since the first order derivative of + # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x), + # and all nth-order derivatives will be O(x**-(n + 2)). This makes any + # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. + math_ops.abs(x)) + + +class SinhArcsinh(bijector.Bijector): + """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. + + For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this + transformation is a + diffeomorphism of the real line `(-inf, inf)`. The inverse transform is + `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. + + The `SinhArcsinh` transformation of the Normal is described in + [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) + This Bijector allows a similar transformation of any distribution supported on + `(-inf, inf)`. + + #### Meaning of the parameters + + * If `skewness = 0` and `tailweight = 1`, this transform is the identity. + * Positive (negative) `skewness` leads to positive (negative) skew. + * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is + "tilted" to the right. + * positive skew means positive values of `Y` become more likely, and + negative values become less likely. + * Larger (smaller) `tailweight` leads to fatter (thinner) tails. + * Fatter tails mean larger values of `|Y|` become more likely. + * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is + "flat" around `Y = 0`, and a very steep drop-off in the tails. + * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more + peaked at the mode with heavier tails. + + To see the argument about the tails, note that for `|X| >> 1` and + `|X| >> (|skewness| * tailweight)**tailweight`, we have + `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. + """ + + def __init__(self, + skewness=None, + tailweight=None, + event_ndims=0, + validate_args=False, + name="SinhArcsinh"): + """Instantiates the `SinhArcsinh` bijector. + + Args: + skewness: Skewness parameter. Float-type `Tensor`. Default is `0` + of type `float32`. + tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as + `skewness` and broadcastable `shape`. Default is `1` of type `float32`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[skewness, tailweight]): + tailweight = 1. if tailweight is None else tailweight + skewness = 0. if skewness is None else skewness + self._skewness = ops.convert_to_tensor( + skewness, name="skewness") + self._tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=self._skewness.dtype) + check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) + if validate_args: + self._tailweight = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._tailweight, + message="Argument tailweight was not positive") + ], self._tailweight) + super(SinhArcsinh, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def skewness(self): + """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._skewness + + @property + def tailweight(self): + """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._tailweight + + def _forward(self, x): + return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) + + def _inverse(self, y): + return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) + + def _inverse_log_det_jacobian(self, y): + # x = sinh(arcsinh(y) / tailweight - skewness) + # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), + # dx/dy + # = cosh(arcsinh(y) / tailweight - skewness) + # / (tailweight * sqrt(y**2 + 1)) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). + math_ops.log(math_ops.cosh( + math_ops.asinh(y) / self.tailweight - self.skewness) + # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases + # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). + / _sqrtx2p1(y)) + - math_ops.log(self.tailweight), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + # y = sinh((arcsinh(x) + skewness) * tailweight) + # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), + # dy/dx + # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + math_ops.log(math_ops.cosh( + (math_ops.asinh(x) + self.skewness) * self.tailweight) + # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases + # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). + / _sqrtx2p1(x)) + + math_ops.log(self.tailweight), + axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py deleted file mode 100644 index 3a75e4ae9495793901b0da91a5aa3982aab35852..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SinhArcsinh bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "SinhArcsinh", -] - - -def _sqrtx2p1(x): - """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" - return array_ops.where( - math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., - math_ops.sqrt(x**2. + 1.), - # For large x, calculating x**2 can overflow. This can be alleviated by - # considering: - # sqrt(1 + x**2) - # = exp(0.5 log(1 + x**2)) - # = exp(0.5 log(x**2 * (1 + x**-2))) - # = exp(log(x) + 0.5 * log(1 + x**-2)) - # = |x| * exp(0.5 log(1 + x**-2)) - # = |x| * sqrt(1 + x**-2) - # We omit the last term in this approximation. - # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, - # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, - # and higher order gradients, since the first order derivative of - # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x), - # and all nth-order derivatives will be O(x**-(n + 2)). This makes any - # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. - math_ops.abs(x)) - - -class SinhArcsinh(bijector.Bijector): - """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. - - For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this - transformation is a - diffeomorphism of the real line `(-inf, inf)`. The inverse transform is - `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. - - The `SinhArcsinh` transformation of the Normal is described in - [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) - This Bijector allows a similar transformation of any distribution supported on - `(-inf, inf)`. - - #### Meaning of the parameters - - * If `skewness = 0` and `tailweight = 1`, this transform is the identity. - * Positive (negative) `skewness` leads to positive (negative) skew. - * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is - "tilted" to the right. - * positive skew means positive values of `Y` become more likely, and - negative values become less likely. - * Larger (smaller) `tailweight` leads to fatter (thinner) tails. - * Fatter tails mean larger values of `|Y|` become more likely. - * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is - "flat" around `Y = 0`, and a very steep drop-off in the tails. - * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more - peaked at the mode with heavier tails. - - To see the argument about the tails, note that for `|X| >> 1` and - `|X| >> (|skewness| * tailweight)**tailweight`, we have - `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. - """ - - def __init__(self, - skewness=None, - tailweight=None, - event_ndims=0, - validate_args=False, - name="SinhArcsinh"): - """Instantiates the `SinhArcsinh` bijector. - - Args: - skewness: Skewness parameter. Float-type `Tensor`. Default is `0` - of type `float32`. - tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as - `skewness` and broadcastable `shape`. Default is `1` of type `float32`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[skewness, tailweight]): - tailweight = 1. if tailweight is None else tailweight - skewness = 0. if skewness is None else skewness - self._skewness = ops.convert_to_tensor( - skewness, name="skewness") - self._tailweight = ops.convert_to_tensor( - tailweight, name="tailweight", dtype=self._skewness.dtype) - check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) - if validate_args: - self._tailweight = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._tailweight, - message="Argument tailweight was not positive") - ], self._tailweight) - super(SinhArcsinh, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def skewness(self): - """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._skewness - - @property - def tailweight(self): - """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._tailweight - - def _forward(self, x): - return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) - - def _inverse(self, y): - return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) - - def _inverse_log_det_jacobian(self, y): - # x = sinh(arcsinh(y) / tailweight - skewness) - # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), - # dx/dy - # = cosh(arcsinh(y) / tailweight - skewness) - # / (tailweight * sqrt(y**2 + 1)) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - math_ops.asinh(y) / self.tailweight - self.skewness) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). - / _sqrtx2p1(y)) - - math_ops.log(self.tailweight), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - # y = sinh((arcsinh(x) + skewness) * tailweight) - # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), - # dy/dx - # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - (math_ops.asinh(x) + self.skewness) * self.tailweight) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). - / _sqrtx2p1(x)) - + math_ops.log(self.tailweight), - axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index be6608f97880ae68e10b17c815bf2d8438293261..e4a1d3dde230724e74d5076c5bba079590b94a70 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -18,12 +18,232 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SoftmaxCentered"] +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "SoftmaxCentered", +] + + +class SoftmaxCentered(bijector.Bijector): + """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. + + To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a + bijection, the forward transformation appends a value to the input and the + inverse removes this coordinate. The appended coordinate represents a pivot, + e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last + coordinate. + + Because we append a coordinate, this bijector only supports `event_ndim in [0, + 1]`, i.e., scalars and vectors. + + Example Use: + + ```python + bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + # Result: [0.2, 0.3, 0.4, 0.1] + # Extra result: 0.1 + + bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + # Result: tf.log([2, 3, 4]) + # Extra coordinate removed. + ``` + + At first blush it may seem like the [Invariance of domain]( + https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this + implementation is not a bijection. However, the appended dimension + makes the (forward) image non-open and the theorem does not directly apply. + """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="softmax_centered"): + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 1]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") + self._static_event_ndims = event_ndims + super(SoftmaxCentered, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward_event_shape(self, input_shape): + if input_shape.ndims is None: + return input_shape + if input_shape.ndims != self._static_event_ndims: + raise ValueError("input_shape.dims = %d != %d" % + (input_shape.ndims, self._static_event_ndims)) + if input_shape.ndims == 0: + return tensor_shape.TensorShape([2]) + if input_shape.ndims == 1: + return tensor_shape.TensorShape(input_shape[0] + 1) + # Unreachable code: + raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + + def _forward_event_shape_tensor(self, input_shape): + ndims = array_ops.shape(input_shape) + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. + is_zero_or_one = check_ops.assert_equal( + ndims, 0 if self._static_event_ndims == 0 else 1, + message="event_ndims must be 0 or 1") + ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor( + [2], dtype=dtypes.int32, name="output_shape") + return input_shape + 1 + + def _inverse_event_shape(self, output_shape): + if output_shape.ndims is None: + return output_shape + if output_shape.ndims != 1: + raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) + if self._static_event_ndims == 0: + return tensor_shape.TensorShape([]) + return tensor_shape.TensorShape(output_shape[0] - 1) + + def _inverse_event_shape_tensor(self, output_shape): + ndims = array_ops.shape(output_shape)[0] + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. + is_one = check_ops.assert_equal( + ndims, 1, message="event_ndims must be 1") + ndims = control_flow_ops.with_dependencies([is_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") + return array_ops.expand_dims(output_shape[0] - 1, dim=0) + + def _forward(self, x): + # Pad the last dim with a zeros vector. We need this because it lets us + # infer the scale in the inverse function. + y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x + ndims = _get_ndims(y) + y = array_ops.pad(y, paddings=array_ops.one_hot(indices=[-1, ndims - 1], + depth=ndims, + axis=0, + dtype=dtypes.int32)) + # Set shape hints. + if x.shape.ndims is not None: + shape = x.shape.as_list() + if self._static_event_ndims == 0: + shape += [2] + elif shape[-1] is not None: + shape[-1] += 1 + shape = tensor_shape.TensorShape(shape) + y.shape.assert_is_compatible_with(shape) + y.set_shape(shape) + + # Since we only support event_ndims in [0, 1] and we do padding, we always + # reduce over the last dimension, i.e., dim=-1 (which is the default). + return nn_ops.softmax(y) + + def _inverse(self, y): + # To derive the inverse mapping note that: + # y[i] = exp(x[i]) / normalization + # and + # y[end] = 1 / normalization. + # Thus: + # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) + # = log(exp(x[i])/normalization) - log(y[end]) + # = log(y[i]) - log(y[end]) + shape = (np.asarray(y.shape.as_list(), dtype=np.int32) + if y.shape.is_fully_defined() + else array_ops.shape(y, name="shape")) + ndims = _get_ndims(y) + + # Do this first to make sure CSE catches that it'll happen again in + # _inverse_log_det_jacobian. + x = math_ops.log(y) + + # We now extract the last coordinate of the rightmost dimension. + # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. + begin = array_ops.one_hot(indices=ndims-1, + depth=ndims, + on_value=shape[-1]-np.array(1, dtype=shape.dtype), + dtype=shape.dtype) + size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) + log_normalization = -array_ops.strided_slice(x, begin, begin + size) + + # Here we slice out all but the last coordinate; see above for idea. + begin = array_ops.zeros_like(shape) + size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) + x = array_ops.strided_slice(x, begin, begin + size) + + x += log_normalization + + if self._static_event_ndims == 0: + x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) + + # Set shape hints. + if y.shape.ndims is not None: + shape = y.shape.as_list() + if self._static_event_ndims == 0: + shape = shape[:-1] + elif shape[-1] is not None: + shape[-1] -= 1 + shape = tensor_shape.TensorShape(shape) + x.shape.assert_is_compatible_with(shape) + x.set_shape(shape) + + return x + + def _inverse_log_det_jacobian(self, y): + # WLOG, consider the vector case: + # x = log(y[:-1]) - log(y[-1]) + # where, + # y[-1] = 1 - sum(y[:-1]). + # We have: + # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } + # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) + # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } + # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * + # det(diag(y[:-1])) } (2) + # = 1 / { y[-1] prod(y[:-1]) } + # = 1 / prod(y) + # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula + # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector + # docstring "Tip". + # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma + return -math_ops.reduce_sum(math_ops.log(y), axis=-1) + + def _forward_log_det_jacobian(self, x): + if self._static_event_ndims == 0: + return x - 2. * nn_ops.softplus(x) + else: + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + fldj = (-log_normalization + + math_ops.reduce_sum(x - log_normalization, + axis=-1, + keep_dims=True)) + return array_ops.squeeze(fldj, squeeze_dims=-1) + + +def _get_ndims(x): + """Returns `ndims`, statically if possible.""" + if x.shape.ndims is not None: + return x.shape.ndims + return array_ops.rank(x, name="ndims") diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py deleted file mode 100644 index 8645cc1b6b04be75a419342591272f07a4a1711c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SoftmaxCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "SoftmaxCentered", -] - - -class SoftmaxCentered(bijector.Bijector): - """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. - - To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a - bijection, the forward transformation appends a value to the input and the - inverse removes this coordinate. The appended coordinate represents a pivot, - e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last - coordinate. - - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. - - Example Use: - - ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) - # Result: [0.2, 0.3, 0.4, 0.1] - # Extra result: 0.1 - - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) - # Result: tf.log([2, 3, 4]) - # Extra coordinate removed. - ``` - - At first blush it may seem like the [Invariance of domain]( - https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this - implementation is not a bijection. However, the appended dimension - makes the (forward) image non-open and the theorem does not directly apply. - """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="softmax_centered"): - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims - super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: - return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) - - def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 - - def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: - return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) - - def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) - - def _forward(self, x): - # Pad the last dim with a zeros vector. We need this because it lets us - # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - ndims = (y.get_shape().ndims if y.get_shape().ndims is not None - else array_ops.rank(y)) - y = array_ops.pad(y, - paddings=array_ops.concat( - (array_ops.zeros( - (ndims - 1, 2), dtype=dtypes.int32), [[0, 1]]), - 0)) - - # Set shape hints. - if x.get_shape().ndims is not None: - shape = x.get_shape().as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) - y.get_shape().assert_is_compatible_with(shape) - y.set_shape(shape) - - # Since we only support event_ndims in [0, 1] and we do padding, we always - # reduce over the last dimension, i.e., dim=-1 (which is the default). - return nn_ops.softmax(y) - - def _inverse(self, y): - # To derive the inverse mapping note that: - # y[i] = exp(x[i]) / normalization - # and - # y[end] = 1 / normalization. - # Thus: - # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) - # = log(exp(x[i])/normalization) - log(y[end]) - # = log(y[i]) - log(y[end]) - shape = (np.asarray(y.get_shape().as_list(), dtype=np.int32) - if y.get_shape().is_fully_defined() - else array_ops.shape(y, name="shape")) - ndims = y.get_shape().ndims or math_ops.rank(y, name="ndims") - - # Do this first to make sure CSE catches that it'll happen again in - # _inverse_log_det_jacobian. - x = math_ops.log(y) - - # We now extract the last coordinate of the rightmost dimension. - # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. - begin = array_ops.one_hot(indices=ndims-1, - depth=ndims, - on_value=shape[-1]-np.array(1, dtype=shape.dtype), - dtype=shape.dtype) - size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) - log_normalization = -array_ops.strided_slice(x, begin, begin + size) - - # Here we slice out all but the last coordinate; see above for idea. - begin = array_ops.zeros_like(shape) - size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) - x = array_ops.strided_slice(x, begin, begin + size) - - x += log_normalization - - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) - - # Set shape hints. - if y.get_shape().ndims is not None: - shape = y.get_shape().as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) - x.get_shape().assert_is_compatible_with(shape) - x.set_shape(shape) - - return x - - def _inverse_log_det_jacobian(self, y): - # WLOG, consider the vector case: - # x = log(y[:-1]) - log(y[-1]) - # where, - # y[-1] = 1 - sum(y[:-1]). - # We have: - # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } - # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) - # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } - # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * - # det(diag(y[:-1])) } (2) - # = 1 / { y[-1] prod(y[:-1]) } - # = 1 / prod(y) - # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula - # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector - # docstring "Tip". - # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma - return -math_ops.reduce_sum(math_ops.log(y), axis=-1) - - def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 250a1144b53bb43271ff7ee494604d9bae6feda8..81957fcf78922fa15fd20a25d144071f431161ae 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -18,12 +18,127 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softplus_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["Softplus"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Softplus", +] + + +class Softplus(bijector.Bijector): + """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. + + The softplus `Bijector` has the following two useful properties: + + * The domain is the positive real numbers + * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as + the `Exp` `Bijector`. + + The optional nonzero `hinge_softness` parameter changes the transition at + zero. With `hinge_softness = c`, the bijector is: + + ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` + + For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, + so the behavior for large `x` is the same as the standard softplus. + + As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, + approaching `max(0, x)`. + + * `c = 1` is the default. + * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. + * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. + * `c = 0` results in a non-bijective transformation and triggers an exception. + + Example Use: + + ```python + # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + softplus = Softplus(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + log(1 + exp(x)) == softplus.forward(x) + log(exp(x) - 1) == softplus.inverse(x) + ``` + + Note: log(.) and exp(.) are applied element-wise but the Jacobian is a + reduction over the event space. + """ + + @distribution_util.AppendDocstring( + kwargs_dict={ + "hinge_softness": ( + "Nonzero floating point `Tensor`. Controls the softness of what " + "would otherwise be a kink at the origin. Default is 1.0")}) + def __init__(self, + event_ndims=0, + hinge_softness=None, + validate_args=False, + name="softplus"): + with ops.name_scope(name, values=[hinge_softness]): + if hinge_softness is not None: + self._hinge_softness = ops.convert_to_tensor( + hinge_softness, name="hinge_softness") + else: + self._hinge_softness = None + if validate_args: + nonzero_check = check_ops.assert_none_equal( + ops.convert_to_tensor( + 0, dtype=self.hinge_softness.dtype), + self.hinge_softness, + message="hinge_softness must be non-zero") + self._hinge_softness = control_flow_ops.with_dependencies( + [nonzero_check], self.hinge_softness) + + super(Softplus, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self.hinge_softness is None: + return nn_ops.softplus(x) + hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) + return hinge_softness * nn_ops.softplus(x / hinge_softness) + + def _inverse(self, y): + if self.hinge_softness is None: + return distribution_util.softplus_inverse(y) + hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) + return hinge_softness * distribution_util.softplus_inverse( + y / hinge_softness) + + def _inverse_log_det_jacobian(self, y): + # Could also do: + # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), + # axis=event_dims) + # but the following is more numerically stable. Ie, + # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] + # ==> dX/dY = exp{Y} / (exp{Y} - 1) + # = 1 / (1 - exp{-Y}), + # which is the most stable for large Y > 0. For small Y, we use + # 1 - exp{-Y} approx Y. + if self.hinge_softness is not None: + y /= math_ops.cast(self.hinge_softness, y.dtype) + return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), + axis=self._event_dims_tensor(y)) + + def _forward_log_det_jacobian(self, x): + if self.hinge_softness is not None: + x /= math_ops.cast(self.hinge_softness, x.dtype) + return -math_ops.reduce_sum(nn_ops.softplus(-x), + axis=self._event_dims_tensor(x)) + + @property + def hinge_softness(self): + return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py deleted file mode 100644 index 81957fcf78922fa15fd20a25d144071f431161ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Softplus bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "Softplus", -] - - -class Softplus(bijector.Bijector): - """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. - - The softplus `Bijector` has the following two useful properties: - - * The domain is the positive real numbers - * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as - the `Exp` `Bijector`. - - The optional nonzero `hinge_softness` parameter changes the transition at - zero. With `hinge_softness = c`, the bijector is: - - ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` - - For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, - so the behavior for large `x` is the same as the standard softplus. - - As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, - approaching `max(0, x)`. - - * `c = 1` is the default. - * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. - * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. - * `c = 0` results in a non-bijective transformation and triggers an exception. - - Example Use: - - ```python - # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - softplus = Softplus(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - log(1 + exp(x)) == softplus.forward(x) - log(exp(x) - 1) == softplus.inverse(x) - ``` - - Note: log(.) and exp(.) are applied element-wise but the Jacobian is a - reduction over the event space. - """ - - @distribution_util.AppendDocstring( - kwargs_dict={ - "hinge_softness": ( - "Nonzero floating point `Tensor`. Controls the softness of what " - "would otherwise be a kink at the origin. Default is 1.0")}) - def __init__(self, - event_ndims=0, - hinge_softness=None, - validate_args=False, - name="softplus"): - with ops.name_scope(name, values=[hinge_softness]): - if hinge_softness is not None: - self._hinge_softness = ops.convert_to_tensor( - hinge_softness, name="hinge_softness") - else: - self._hinge_softness = None - if validate_args: - nonzero_check = check_ops.assert_none_equal( - ops.convert_to_tensor( - 0, dtype=self.hinge_softness.dtype), - self.hinge_softness, - message="hinge_softness must be non-zero") - self._hinge_softness = control_flow_ops.with_dependencies( - [nonzero_check], self.hinge_softness) - - super(Softplus, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self.hinge_softness is None: - return nn_ops.softplus(x) - hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) - return hinge_softness * nn_ops.softplus(x / hinge_softness) - - def _inverse(self, y): - if self.hinge_softness is None: - return distribution_util.softplus_inverse(y) - hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) - return hinge_softness * distribution_util.softplus_inverse( - y / hinge_softness) - - def _inverse_log_det_jacobian(self, y): - # Could also do: - # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), - # axis=event_dims) - # but the following is more numerically stable. Ie, - # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] - # ==> dX/dY = exp{Y} / (exp{Y} - 1) - # = 1 / (1 - exp{-Y}), - # which is the most stable for large Y > 0. For small Y, we use - # 1 - exp{-Y} approx Y. - if self.hinge_softness is not None: - y /= math_ops.cast(self.hinge_softness, y.dtype) - return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), - axis=self._event_dims_tensor(y)) - - def _forward_log_det_jacobian(self, x): - if self.hinge_softness is not None: - x /= math_ops.cast(self.hinge_softness, x.dtype) - return -math_ops.reduce_sum(nn_ops.softplus(-x), - axis=self._event_dims_tensor(x)) - - @property - def hinge_softness(self): - return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index d439f28884d8bd7f2b808317e10c5b5e44bfcfa2..00520bcda85e9527767e6342bf75f10667c264a8 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -18,12 +18,132 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.weibull_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Weibull"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Weibull", +] + + +class Weibull(bijector.Bijector): + """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. + + This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): + + ```none + Y ~ Weibull(scale, concentration) + pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( + scale / concentration) ** (concentration - 1) * exp( + -(y / scale) ** concentration) + ``` + """ + + def __init__(self, + scale=1., + concentration=1., + event_ndims=0, + validate_args=False, + name="weibull"): + """Instantiates the `Weibull` bijector. + + Args: + scale: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `concentration`. + This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + concentration: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[scale, concentration]): + self._scale = ops.convert_to_tensor(scale, name="scale") + self._concentration = ops.convert_to_tensor( + concentration, name="concentration") + check_ops.assert_same_float_dtype([self._scale, self._concentration]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, + message="Argument scale was not positive") + ], self._scale) + self._concentration = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._concentration, + message="Argument concentration was not positive") + ], self._concentration) + + super(Weibull, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def scale(self): + """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._scale + + @property + def concentration(self): + """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._concentration + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + return -math_ops.expm1(-((x / self.scale) ** self.concentration)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + -math_ops.log1p(-y) + + (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + + math_ops.log(self.scale / self.concentration), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + -(x / self.scale) ** self.concentration + + (self.concentration - 1) * math_ops.log(x) + + math_ops.log(self.concentration) + + -self.concentration * math_ops.log(self.scale), + axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args: + return x + is_valid = check_ops.assert_non_negative( + x, + message="Forward transformation input must be at least {}.".format(0)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py deleted file mode 100644 index 00520bcda85e9527767e6342bf75f10667c264a8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Weibull bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Weibull", -] - - -class Weibull(bijector.Bijector): - """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. - - This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): - - ```none - Y ~ Weibull(scale, concentration) - pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( - scale / concentration) ** (concentration - 1) * exp( - -(y / scale) ** concentration) - ``` - """ - - def __init__(self, - scale=1., - concentration=1., - event_ndims=0, - validate_args=False, - name="weibull"): - """Instantiates the `Weibull` bijector. - - Args: - scale: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `concentration`. - This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - concentration: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[scale, concentration]): - self._scale = ops.convert_to_tensor(scale, name="scale") - self._concentration = ops.convert_to_tensor( - concentration, name="concentration") - check_ops.assert_same_float_dtype([self._scale, self._concentration]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, - message="Argument scale was not positive") - ], self._scale) - self._concentration = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._concentration, - message="Argument concentration was not positive") - ], self._concentration) - - super(Weibull, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def scale(self): - """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._scale - - @property - def concentration(self): - """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._concentration - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - return -math_ops.expm1(-((x / self.scale) ** self.concentration)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - -math_ops.log1p(-y) + - (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + - math_ops.log(self.scale / self.concentration), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - -(x / self.scale) ** self.concentration + - (self.concentration - 1) * math_ops.log(x) + - math_ops.log(self.concentration) + - -self.concentration * math_ops.log(self.scale), - axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args: - return x - is_valid = check_ops.assert_non_negative( - x, - message="Forward transformation input must be at least {}.".format(0)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5d724a2a945ed8f9c159d8314327c6f994d1db --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -0,0 +1,221 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Cauchy distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import distribution + +__all__ = [ + "Cauchy", +] + + +class Cauchy(distribution.Distribution): + """The Cauchy distribution with location `loc` and scale `scale`. + + #### Mathematical details + + The probability density function (pdf) is, + + ```none + pdf(x; loc, scale) = 1 / (pi scale (1 + z**2)) + z = (x - loc) / scale + ``` + where `loc` is the location, and `scale` is the scale. + + The Cauchy distribution is a member of the [location-scale family]( + https://en.wikipedia.org/wiki/Location-scale_family), i.e. + `Y ~ Cauchy(loc, scale)` is equivalent to, + + ```none + X ~ Cauchy(loc=0, scale=1) + Y = loc + scale * X + ``` + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + tfd = tf.contrib.distributions + + # Define a single scalar Cauchy distribution. + dist = tfd.Cauchy(loc=0., scale=3.) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued Cauchy distributions. + dist = tfd.Cauchy(loc=[1, 2.], scale=[11, 22.]) + + # Evaluate the pdf of the first distribution on 0, and the second on 1.5, + # returning a length two tensor. + dist.prob([0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + + # Arguments are broadcast when possible. + # Define a batch of two scalar valued Cauchy distributions. + # Both have median 1, but different scales. + dist = tfd.Cauchy(loc=1., scale=[11, 22.]) + + # Evaluate the pdf of both distributions on the same point, 3.0, + # returning a length 2 tensor. + dist.prob(3.) + ``` + + """ + + def __init__(self, + loc, + scale, + validate_args=False, + allow_nan_stats=True, + name="Cauchy"): + """Construct Cauchy distributions. + + The parameters `loc` and `scale` must be shaped in a way that supports + broadcasting (e.g. `loc + scale` is a valid operation). + + Args: + loc: Floating point tensor; the modes of the distribution(s). + scale: Floating point tensor; the locations of the distribution(s). + Must contain only positive values. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + + Raises: + TypeError: if `loc` and `scale` have different `dtype`. + """ + parameters = locals() + with ops.name_scope(name, values=[loc, scale]): + with ops.control_dependencies([check_ops.assert_positive(scale)] + if validate_args else []): + self._loc = array_ops.identity(loc, name="loc") + self._scale = array_ops.identity(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + super(Cauchy, self).__init__( + dtype=self._scale.dtype, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._loc, self._scale], + name=name) + + @staticmethod + def _param_shapes(sample_shape): + return dict( + zip(("loc", "scale"), + ([ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)] * 2))) + + @property + def loc(self): + """Distribution parameter for the mean.""" + return self._loc + + @property + def scale(self): + """Distribution parameter for standard deviation.""" + return self._scale + + def _batch_shape_tensor(self): + return array_ops.broadcast_dynamic_shape( + array_ops.shape(self.loc), array_ops.shape(self.scale)) + + def _batch_shape(self): + return array_ops.broadcast_static_shape(self.loc.shape, self.scale.shape) + + def _event_shape_tensor(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + probs = random_ops.random_uniform( + shape=shape, minval=0., maxval=1., dtype=self.dtype, seed=seed) + return self._quantile(probs) + + def _log_prob(self, x): + return self._log_unnormalized_prob(x) - self._log_normalization() + + def _cdf(self, x): + return math_ops.atan(self._z(x)) / np.pi + 0.5 + + def _log_cdf(self, x): + return math_ops.log1p(2 / np.pi * math_ops.atan(self._z(x))) - np.log(2) + + def _log_unnormalized_prob(self, x): + return -math_ops.log1p(math_ops.square(self._z(x))) + + def _log_normalization(self): + return np.log(np.pi) + math_ops.log(self.scale) + + def _entropy(self): + h = np.log(4 * np.pi) + math_ops.log(self.scale) + return h * array_ops.ones_like(self.loc) + + def _quantile(self, p): + return self.loc + self.scale * math_ops.tan(np.pi * (p - 0.5)) + + def _mode(self): + return self.loc * array_ops.ones_like(self.scale) + + def _z(self, x): + """Standardize input `x`.""" + with ops.name_scope("standardize", values=[x]): + return (x - self.loc) / self.scale + + def _inv_z(self, z): + """Reconstruct input `x` from a its normalized version.""" + with ops.name_scope("reconstruct", values=[z]): + return z * self.scale + self.loc + + def _mean(self): + if self.allow_nan_stats: + return array_ops.fill(self.batch_shape_tensor(), + self.dtype.as_numpy_dtype(np.nan)) + else: + raise ValueError("`mean` is undefined for Cauchy distribution.") + + def _stddev(self): + if self.allow_nan_stats: + return array_ops.fill(self.batch_shape_tensor(), + self.dtype.as_numpy_dtype(np.nan)) + else: + raise ValueError("`stddev` is undefined for Cauchy distribution.") diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 599c855cda434d9249187d5d154d50a8a8c49a6c..1d4c5660d8d73b7b6a7e758fc834ccfddeb5c8ea 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -121,7 +121,7 @@ class ConditionalTransformedDistribution( log_prob = self.distribution.log_prob(x, **distribution_kwargs) if self._is_maybe_event_override: log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) - return ildj + log_prob + return math_ops.cast(ildj, log_prob.dtype) + log_prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _prob(self, y, bijector_kwargs=None, distribution_kwargs=None): @@ -143,7 +143,7 @@ class ConditionalTransformedDistribution( prob = self.distribution.prob(x, **distribution_kwargs) if self._is_maybe_event_override: prob = math_ops.reduce_prod(prob, self._reduce_event_indices) - return math_ops.exp(ildj) * prob + return math_ops.exp(math_ops.cast(ildj, prob.dtype)) * prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None): diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 850d08d1bd69ebc7661557d648e2bffe77e6a908..8049522e9f5dc26b244b7e710a9ae8b981efd6b6 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -290,8 +290,10 @@ class VectorDeterministic(_BaseDeterministic): #### Examples ```python + tfd = tf.contrib.distributions + # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. - constant = tf.contrib.distributions.Deterministic([0., 2.]) + constant = tfd.Deterministic([0., 2.]) constant.prob([0., 2.]) ==> 1. constant.prob([0., 3.]) @@ -299,7 +301,7 @@ class VectorDeterministic(_BaseDeterministic): # Initialize a [3] batch of constants on R^2. loc = [[0., 1.], [2., 3.], [4., 5.]] - constant = constant_lib.VectorDeterministic(loc) + constant = tfd.VectorDeterministic(loc) constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]]) ==> [1., 0., 0.] ``` diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index ba8d3c639b397422f0f6210ba9f48650f0da1e3e..d0efaefb8e78ddf4436e9e5a112d2c1cdddaf3b5 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -62,15 +62,17 @@ class _Gumbel(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Gumbel distribution. - dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.) + dist = tfd.Gumbel(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Gumbels. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Gumbel(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -85,7 +87,7 @@ class _Gumbel(distribution.Distribution): ```python # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.]) + dist = tfd.Gumbel(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0751a6e0b78cb3d79bd3478e740bb05cd26428 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Half Normal distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import special_math + + +__all__ = [ + "HalfNormal", +] + + +class HalfNormal(distribution.Distribution): + """The Half Normal distribution with scale `scale`. + + #### Mathematical details + + The half normal is a transformation of a centered normal distribution. + If some random variable `X` has normal distribution, + ```none + X ~ Normal(0.0, scale) + Y = |X| + ``` + Then `Y` will have half normal distribution. The probability density + function (pdf) is: + + ```none + pdf(x; scale, x > 0) = sqrt(2) / (scale * sqrt(pi)) * + exp(- 1/2 * (x / scale) ** 2) + ) + ``` + Where `scale = sigma` is the standard deviation of the underlying normal + distribution. + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + # Define a single scalar HalfNormal distribution. + dist = tf.contrib.distributions.HalfNormal(scale=3.0) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued HalfNormals. + # The first has scale 11.0, the second 22.0 + dist = tf.contrib.distributions.HalfNormal(scale=[11.0, 22.0]) + + # Evaluate the pdf of the first distribution on 1.0, and the second on 1.5, + # returning a length two tensor. + dist.prob([1.0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + ``` + + """ + + def __init__(self, + scale, + validate_args=False, + allow_nan_stats=True, + name="HalfNormal"): + """Construct HalfNormals with scale `scale`. + + Args: + scale: Floating point tensor; the scales of the distribution(s). + Must contain only positive values. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + parameters = locals() + with ops.name_scope(name, values=[scale]): + with ops.control_dependencies([check_ops.assert_positive(scale)] if + validate_args else []): + self._scale = array_ops.identity(scale, name="scale") + super(HalfNormal, self).__init__( + dtype=self._scale.dtype, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._scale], + name=name) + + @staticmethod + def _param_shapes(sample_shape): + return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} + + @property + def scale(self): + """Distribution parameter for the scale.""" + return self._scale + + def _batch_shape_tensor(self): + return array_ops.shape(self.scale) + + def _batch_shape(self): + return self.scale.shape + + def _event_shape_tensor(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + sampled = random_ops.random_normal( + shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=seed) + return math_ops.abs(sampled * self.scale) + + def _prob(self, x): + coeff = np.sqrt(2) / self.scale / np.sqrt(np.pi) + pdf = coeff * math_ops.exp(- 0.5 * (x / self.scale) ** 2) + return pdf * math_ops.cast(x >= 0, self.dtype) + + def _cdf(self, x): + truncated_x = nn.relu(x) + return math_ops.erf(truncated_x / self.scale / np.sqrt(2.0)) + + def _entropy(self): + return 0.5 * math_ops.log(np.pi * self.scale ** 2.0 / 2.0) + 0.5 + + def _mean(self): + return self.scale * np.sqrt(2.0) / np.sqrt(np.pi) + + def _quantile(self, p): + return np.sqrt(2.0) * self.scale * special_math.erfinv(p) + + def _mode(self): + return array_ops.zeros(self.batch_shape_tensor()) + + def _variance(self): + return self.scale ** 2.0 * (1.0 - 2.0 / np.pi) diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 6a74ca9a0ae1ad30081d21cc15a65be052a99e2a..cbce005013281ff3c58c94d525d5ce7a865d725a 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -68,11 +68,11 @@ class Independent(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Make independent distribution from a 2-batch Normal. - ind = ds.Independent( - distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), + ind = tfd.Independent( + distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]), reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. @@ -80,8 +80,8 @@ class Independent(distribution_lib.Distribution): ind.event_shape # ==> [2] # Make independent distribution from a 2-batch bivariate Normal. - ind = ds.Independent( - distribution=ds.MultivariateNormalDiag( + ind = tfd.Independent( + distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]), reinterpreted_batch_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 956dee38a378813434656a28a69c89b6ec1e8b72..ee4d86867d48b20e97757bcec57d452085814b80 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -88,8 +88,9 @@ class InverseGamma(distribution.Distribution): #### Examples ```python - dist = InverseGamma(concentration=3.0, rate=2.0) - dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + tfd = tf.contrib.distributions + dist = tfd.InverseGamma(concentration=3.0, rate=2.0) + dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) ``` """ diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 48794a48828fe796e233e968d8c755136ce166ad..473677f8d91b184e029f345bb05f5c5d63df7a40 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -60,15 +60,17 @@ class Logistic(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Logistic distribution. - dist = tf.contrib.distributions.Logistic(loc=0., scale=3.) + dist = tfd.Logistic(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Logistics. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Logistic(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,14 +78,11 @@ class Logistic(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - Arguments are broadcast when possible. - - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.]) + dist = tfd.Logistic(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index e676931d9145e72907d990148ee2d180e0da0258..f2d492f5489a197157558ae727416b51db04793e 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -49,13 +49,13 @@ class Mixture(distribution.Distribution): ```python # Create a mixture of two Gaussians: - ds = tf.contrib.distributions + tfd = tf.contrib.distributions mix = 0.3 - bimix_gauss = ds.Mixture( - cat=ds.Categorical(probs=[mix, 1.-mix]), + bimix_gauss = tfd.Mixture( + cat=tfd.Categorical(probs=[mix, 1.-mix]), components=[ - ds.Normal(loc=-1., scale=0.1), - ds.Normal(loc=+1., scale=0.5), + tfd.Normal(loc=-1., scale=0.1), + tfd.Normal(loc=+1., scale=0.5), ]) # Plot the PDF. diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 5558ef0f255db684b229d129666634e50c625887..0ca236c3761f9d3a0fcc79ff9db792319108db0d 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -43,15 +43,14 @@ class MixtureSameFamily(distribution.Distribution): #### Examples ```python - import matplotlib.pyplot as plt - ds = tf.contrib.distributions + tfd = tf.contrib.distributions ### Create a mixture of two scalar Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.Normal( + components_distribution=tfd.Normal( loc=[-1., 1], # One for each component. scale=[0.1, 0.5])) # And same here. @@ -63,14 +62,15 @@ class MixtureSameFamily(distribution.Distribution): # Plot PDF. x = np.linspace(-2., 3., int(1e4), dtype=np.float32) + import matplotlib.pyplot as plt plt.plot(x, gm.prob(x).eval()); ### Create a mixture of two Bivariate Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.MultivariateNormalDiag( + components_distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], # component 1 [1, -1]], # component 2 scale_identity_multiplier=[.3, .6])) @@ -320,13 +320,14 @@ class MixtureSameFamily(distribution.Distribution): return array_ops.shape(d.batch_shape_tensor())[0] dist_batch_ndims = _get_ndims(self) cat_batch_ndims = _get_ndims(self.mixture_distribution) - bnd = distribution_util.pick_vector( + pad_ndims = array_ops.where( self.mixture_distribution.is_scalar_batch(), - [dist_batch_ndims], [cat_batch_ndims])[0] + dist_batch_ndims, + dist_batch_ndims - cat_batch_ndims) s = array_ops.shape(x) x = array_ops.reshape(x, shape=array_ops.concat([ s[:-1], - array_ops.ones([bnd], dtype=dtypes.int32), + array_ops.ones([pad_ndims], dtype=dtypes.int32), s[-1:], array_ops.ones([self._event_ndims], dtype=dtypes.int32), ], axis=0)) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index 163cf75d990d5fe7ec1e3aaf0040fc71f61774a7..e862552880f4073c8fa8e90134d0633e7484b0bf 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -84,10 +84,10 @@ class MultivariateNormalDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -101,7 +101,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -119,7 +119,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate Gaussians. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 040bc230722194316b8a74627344e315a2578281..413e88f03ae0286c294f3404549a73e1a47dcff7 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -86,7 +86,7 @@ class MultivariateNormalDiagPlusLowRank( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is @@ -97,7 +97,7 @@ class MultivariateNormalDiagPlusLowRank( [-1, 1], [2, -0.5]] # shape: [3, 2] m = [4., 5] # shape: [2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu scale_diag=d scale_perturb_factor=U, @@ -118,7 +118,7 @@ class MultivariateNormalDiagPlusLowRank( m = [[0.1, 0.2], [0.4, 0.5]] # shape: [b, r] = [2, 2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu, scale_perturb_factor=U, scale_perturb_diag=m) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index f9952b2069d6dfd2593e6bd71ede0badf44cdf98..8e69dadfb42e8d885b3af552b1f093b2857a6aa3 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -73,14 +73,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] cov = [[ 0.36, 0.12, 0.06], [ 0.12, 0.29, -0.13], [ 0.06, -0.13, 0.26]] - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance_matrix=cov) @@ -100,7 +100,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] covariance_matrix = ... # shape: [2, 3, 3], symmetric, positive definite. - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance=covariance_matrix) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 300bdd5f6064a1cc9c336689ac4fae04338edb30..a7399792892f4c179c05168184d76ec95c168b51 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -90,8 +90,7 @@ class MultivariateNormalLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -103,9 +102,9 @@ class MultivariateNormalLinearOperator( # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale)) + scale=tf.linalg.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -122,9 +121,9 @@ class MultivariateNormalLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 260dcc18f513d5440d3d39368539274c03faa72a..6c7dc4ca7aaf5b3a20b072e9360d15528ad10556 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -76,12 +76,13 @@ class MultivariateNormalTriL( ``` Trainable (batch) lower-triangular matrices can be created with - `ds.matrix_diag_transform()` and/or `ds.fill_triangular()` + `tf.contrib.distributions.matrix_diag_transform()` and/or + `tf.contrib.distributions.fill_triangular()` #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -92,7 +93,7 @@ class MultivariateNormalTriL( # ==> [[ 0.6, 0. , 0. ], # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=scale) @@ -112,7 +113,7 @@ class MultivariateNormalTriL( mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal. - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=tril) @@ -124,9 +125,9 @@ class MultivariateNormalTriL( # Instantiate a "learnable" MVN. dims = 4 with tf.variable_scope("model"): - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), - scale_tril=ds.fill_triangular( + scale_tril=tfd.fill_triangular( tf.get_variable(shape=[dims * (dims + 1) / 2], dtype=tf.float32, name="chol_Sigma"))) ``` diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 8a95038a3c8eccf8a75fea79d0a62f9883b4f13a..2701c36fb53b1ae3fd736be3b1288e3dd40c739a 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -107,10 +107,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions + # Create two batches of PoissonLogNormalQuadratureCompounds, one with # prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.` - pln = ds.PoissonLogNormalQuadratureCompound( + pln = tfd.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., quadrature_grid_and_probs=( @@ -292,7 +293,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): # where, # # Z|v ~ interpolate_affine[v](distribution) - # V ~ mixture_distrubution + # V ~ mixture_distribution # # thus, # diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index b05f15771a3a94779ffddea8f16ad2fa4ea2fdd1..c4b8f055b7fbc3f0835b503eddd7617610326d8c 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -115,7 +115,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. - Default is `ds.Normal(0., 1.)`. + Default is `tf.distributions.Normal(0., 1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 92043d6a08833888c36009261addca0d14949ea8..904724af429f3cb5835f6e05abcb574467ef6918 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -188,8 +188,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and # another with mix_loc=[1]. In both cases, `K=2` and the affine @@ -197,20 +196,20 @@ class VectorDiffeomixture(distribution_lib.Distribution): # k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity # k=1: loc=[2.]*dims scale=LinOpDiag dims = 5 - vdm = ds.VectorDiffeomixture( + vdm = tfd.VectorDiffeomixture( mix_loc=[[0.], [1]], mix_scale=[1.], - distribution=ds.Normal(loc=0., scale=1.), + distribution=tfd.Normal(loc=0., scale=1.), loc=[ None, # Equivalent to `np.zeros(dims, dtype=np.float32)`. np.float32([2.]*dims), ], scale=[ - la.LinearOperatorScaledIdentity( + tf.linalg.LinearOperatorScaledIdentity( num_rows=dims, multiplier=np.float32(1.1), is_positive_definite=True), - la.LinearOperatorDiag( + tf.linalg.LinearOperatorDiag( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 356d78b67a8107750f68f7f84d73d1231f5b2b03..526fe2d39aef9aed833b889de80e849c469435e7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -89,14 +89,13 @@ class VectorExponentialDiag( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. # The first component has pdf exp{-x}, the second 0.5 exp{-x / 2} - vex = ds.VectorExponentialDiag(scale_diag=[1., 2.]) + vex = tfd.VectorExponentialDiag(scale_diag=[1., 2.]) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([3., 4.]).eval() # shape: [] @@ -107,7 +106,7 @@ class VectorExponentialDiag( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialDiag(loc, scale_diag) + vex = tfd.VectorExponentialDiag(loc, scale_diag) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index b313a851b381e5b3a057fd17e6c2ef4eb0fc34f1..9d5fd9ac4178a1ae29b1ce32f304b22fd3d234dc 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -107,16 +107,15 @@ class VectorExponentialLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. mat = [[1.0, 0.1], [0.1, 1.0]] - vex = ds.VectorExponentialLinearOperator( - scale=la.LinearOperatorFullMatrix(mat)) + vex = tfd.VectorExponentialLinearOperator( + scale=tf.linalg.LinearOperatorFullMatrix(mat)) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([1., 2.]).eval() # shape: [] @@ -127,9 +126,9 @@ class VectorExponentialLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialLinearOperator( + vex = tfd.VectorExponentialLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 0e3867809a820f49cfa7f5282c47f786626481a6..8dd983b750d9b39775e570800006011f4968f7f3 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -101,10 +101,10 @@ class VectorLaplaceDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -118,7 +118,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -136,7 +136,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate VectorLaplace's. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index c7abdbb4caf9bee4cbd5991eb5d652f20dd0f8d1..ec485c95c15da2794b67d2699d2bdd9db97bb6c4 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -109,8 +109,7 @@ class VectorLaplaceLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate VectorLaplace with some desired covariance. mu = [1., 2, 3] @@ -124,9 +123,9 @@ class VectorLaplaceLinearOperator( # [ 0.1, -0.3, 0.4]]) # Divide scale by sqrt(2) so that the final covariance will be what we want. - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) + scale=tf.linalg.LinearOperatorLowerTriangular(scale / tf.sqrt(2.))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -143,9 +142,9 @@ class VectorLaplaceLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 544a8710709a0afb56c6ae6f36d35de892e8e420..e1ccf116457a97261b9ce3965552764771d3bdd2 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -143,7 +143,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): broadcastable with `event_shape`. distribution: `tf.Distribution`-like instance. Distribution from which `k` iid samples are used as input to transformation `F`. Default is - `ds.Normal(0., 1.)`. + `tf.distributions.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 29d41ab81c62d621c3c3533e1449341e9a085645..8c67647a618d22a58428d78865c4ebf7d98bdf9e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -91,14 +91,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): Extra leading dimensions, if provided, allow for batches. ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate vector Student's t-distribution. mu = [1., 2, 3] chol = [[1., 0, 0.], [1, 3, 0], [1, 2, 3]] - vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(df=2, loc=mu, scale_tril=chol) # Evaluate this on an observation in R^3, returning a scalar. vt.prob([-1., 0, 1]) @@ -107,7 +107,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): mu = [[1., 2, 3], [11, 22, 33]] chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal. - vt = ds.VectorStudentT(loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(loc=mu, scale_tril=chol) # Evaluate this on a two observations, each in R^3, returning a length two # tensor. diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index ae4b07799f5c123b68529443a1765fbfbac05492..09242ee47ddd044dfc99e22d5b7751a989c86485 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -1,4 +1,4 @@ -# TensorFlow Eager Execution +# Eager Execution > *WARNING*: This is a preview/pre-alpha version. The API and performance > characteristics are subject to change. @@ -76,3 +76,6 @@ For an introduction to eager execution in TensorFlow, see: ## Changelog - 2017/10/31: Initial preview release. +- 2017/12/01: Example of dynamic neural network: + [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). + See [README.md](python/examples/spinn/README.md) for details. diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 2b84bc2e9b7453fac99ea2becc328ca854cf555d..fb667cd91bdb5296e6aacf1963981ce5cfd76be3 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -12,16 +12,16 @@ py_library( visibility = ["//visibility:public"], deps = [ ":datasets", - ":evaluator", ":metrics", ":network", ":saver", - ":summary_writer", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:script_ops", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", @@ -51,21 +51,22 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/eager:context", ], ) -py_test( +cuda_py_test( name = "datasets_test", srcs = ["datasets_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":datasets", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -103,37 +104,6 @@ cuda_py_test( ], ) -py_library( - name = "summary_writer", - srcs = ["summary_writer.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/summary:gen_summary_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary_op_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - ], -) - -cuda_py_test( - name = "summary_writer_test", - srcs = ["summary_writer_test.py"], - additional_deps = [ - ":summary_writer", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:constant_op", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_library( name = "metrics", srcs = [ @@ -165,11 +135,9 @@ py_test( ":metrics", "//tensorflow/contrib/summary:summary_ops", "//tensorflow/contrib/summary:summary_test_util", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", - "//tensorflow/python:lib", - "//tensorflow/python:platform", + "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", @@ -219,8 +187,11 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/python:framework_ops", "//tensorflow/python:layers_base", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", "//tensorflow/python/estimator:util", ], ) @@ -231,13 +202,17 @@ py_test( srcs_version = "PY2AND3", deps = [ ":network", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:constant_op", + "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:function", "//tensorflow/python/eager:test", ], ) diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 98e6983658aed77277d87915ff26a8c676224503..b559cce6b12a809d671ce7855680063f02a4ac22 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -20,11 +20,15 @@ from __future__ import print_function import threading +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops @@ -32,12 +36,12 @@ _uid_counter = 0 _uid_lock = threading.Lock() -def _iterator_shared_name(): +def _generate_shared_name(prefix): with _uid_lock: global _uid_counter uid = _uid_counter _uid_counter += 1 - return "eager_iterator_{}".format(uid) + return "{}_{}".format(prefix, uid) class Iterator(object): @@ -72,11 +76,12 @@ class Iterator(object): with ops.device("/device:CPU:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_types = dataset.output_types + self._output_shapes = dataset.output_shapes self._flat_output_types = nest.flatten(dataset.output_types) self._flat_output_shapes = nest.flatten(dataset.output_shapes) self._resource = gen_dataset_ops.iterator( container="", - shared_name=_iterator_shared_name(), + shared_name=_generate_shared_name("eager_iterator"), output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) @@ -84,6 +89,35 @@ class Iterator(object): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="/device:CPU:0") self._device = context.context().device_name + self._buffer_resource_handle = None + if not context.context().device_spec.device_type: + is_remote_device = False + else: + is_remote_device = context.context().device_spec.device_type != "CPU" + if is_remote_device: + with ops.device("/device:CPU:0"): + iter_string_handle = gen_dataset_ops.iterator_to_string_handle( + self._resource) + + @function.Defun(dtypes.string) + def remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, self._output_types, self._output_shapes) + return remote_iterator.get_next() + + remote_fn.add_to_graph(None) + target = constant_op.constant("/device:CPU:0") + with ops.device(self._device): + self._buffer_resource_handle = prefetching_ops.function_buffering_resource( + string_arg=iter_string_handle, + f=remote_fn, + target_device=target, + buffer_size=10, + thread_pool_size=1, + container="", + shared_name=_generate_shared_name("function_buffer_resource")) + self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._buffer_resource_handle, handle_device=self._device) def __iter__(self): return self @@ -93,20 +127,20 @@ class Iterator(object): def next(self): """Return the next tf.Tensor from the dataset.""" - try: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - with ops.device("/device:CPU:0"): - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - except errors.OutOfRangeError: - raise StopIteration - # Copies tensors from CPU to the current device if necessary. - # TODO(rohanj): This should be replaced by the mechanism to have the - # runtime's threads copy tensors to the destination device. with ops.device(self._device): - ret = [array_ops.identity(x) for x in ret] + try: + if self._buffer_resource_handle is not None: + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + else: + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + except errors.OutOfRangeError: + raise StopIteration return nest.pack_sequence_as(self._output_types, ret) diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index bd0ab02ecf7ae6025e08dde1c3ddc634db9255c1..3faaeef5903615ea122800a6690117dde682e830 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -110,7 +110,7 @@ class Evaluator(object): return self._all_metric_results() else: def f(): - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( summary_logdir).as_default(), summary_ops.always_record_summaries(): return self._all_metric_results() if context.in_eager_mode(): diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py index 02f82cb216983accc7bc2dfa20cbb1ee0b8d8d26..7d2274db9b051e604266074651f4cbd331f20f48 100644 --- a/tensorflow/contrib/eager/python/evaluator_test.py +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -87,7 +87,7 @@ class EvaluatorTest(test.TestCase): e.all_metric_results(logdir) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 6.0) @@ -136,7 +136,7 @@ class EvaluatorTest(test.TestCase): variables.global_variables_initializer().run() e.run_evaluation(init_op, call_op, results_op) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 6.0) diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index aa21a6ab994acf929890ecebc07a86cf7ebf97db..6aef010a2139c4cd2ae19c008aa21d4e3592ca98 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -11,5 +11,6 @@ py_library( "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index d0130ebd118dbaff4f0161c8b2528764c6103e02..7bc5007c5655bed81b5600ee283c35bd332a1ebe 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -85,7 +85,7 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= - summary_writer = tf.contrib.summary.create_summary_file_writer(logdir) + summary_writer = tf.contrib.summary.create_file_writer(logdir) # Training loop. for i, (xs, ys) in enumerate(tfe.Iterator(dataset)): diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index ae01bac0b560e15f655c883da4ccc1944c07232c..bb121c7704b4772dde520ddc928a13c50ec8bb18 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -190,10 +190,10 @@ def main(_): else: train_dir = None test_dir = None - summary_writer = tf.contrib.summary.create_summary_file_writer( - train_dir, flush_secs=10) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( - test_dir, flush_secs=10, name='test') + summary_writer = tf.contrib.summary.create_file_writer( + train_dir, flush_millis=10000) + test_summary_writer = tf.contrib.summary.create_file_writer( + test_dir, flush_millis=10000, name='test') checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') with tf.device(device): @@ -211,7 +211,7 @@ def main(_): test(model, test_ds) all_variables = ( model.variables - + tfe.get_optimizer_variables(optimizer) + + optimizer.variables() + [global_step]) tfe.Saver(all_variables).save( checkpoint_prefix, global_step=global_step) diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb index 01616f2e7dbab8084153e6554ce0e64c13f5d710..459f2f4a7d2afa153e77069bc3ce0c5360ddd7e2 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb @@ -429,7 +429,9 @@ "cpu_tensor = tf.random_normal([SIZE, SIZE])\n", "\n", "if is_gpu_available:\n", - " gpu_tensor = cpu_tensor.gpu()" + " gpu_tensor = cpu_tensor.gpu()\n", + "else:\n", + " print(\"GPU not available.\")" ] }, { diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb index 3b7e2cd435e7f34cb950545a9fe5ee6eafefde7e..e6c7c117333e1e10aa571dae295e88747bd7d764 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb @@ -383,7 +383,7 @@ "\n", "`implicit_value_and_gradients()` returns a function that accepts the same inputs as the function passed in, and returns a tuple consisting of:\n", "\n", - "1. the value returned by the function passed in (in this case, the loss calculated by `calculate_linear_model_loss()`), and\n", + "1. the value returned by the function passed in (in this case, the loss calculated by `loss_fn()`), and\n", "1. a list of tuples consisting of:\n", " 1. The value of the gradient (a `tf.Tensor`) with respect to a given variable\n", " 1. The corresponding variable (`tf.Variable`)\n", @@ -698,7 +698,7 @@ "source": [ "## Other Ways to Compute Gradients\n", "\n", - "Using our loss function as an example (`calculate_linear_model_loss()`), there are several other ways we could compute gradients:\n", + "Using our loss function as an example (`loss_fn()`), there are several other ways we could compute gradients:\n", "\n", "1. `tfe.implicit_gradients()`\n", "1. `tfe.gradients_function()`\n", @@ -841,7 +841,7 @@ "# tfe.implicit_value_and_gradients() demo\n", "value_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)\n", "\n", - "# Returns only gradients:\n", + "# Returns the value returned by the function passed in, gradients, and variables:\n", "value_gradients_fn(inputs, labels, wb)" ] } diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb index ebcc7027c1d34c47a339a49ede1d80e58ad43780..0088da5c4b583dd13251de5839235de666fe8b78 100644 --- a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb @@ -9,7 +9,7 @@ "source": [ "# Eager Execution Tutorial: Importing Data\n", "\n", - "This notebook demonstrates the use of the [`tf.contrib.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n", + "This notebook demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n", "\n", "* Creating a `Dataset`.\n", "* Iteration over a `Dataset` with eager execution enabled.\n", @@ -64,7 +64,7 @@ "source": [ "# Step 1: Create a source `Dataset`\n", "\n", - "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information." + "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information." ] }, { @@ -83,7 +83,7 @@ }, "outputs": [], "source": [ - "ds_tensors = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", + "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", "\n", "# Create a CSV file\n", "import tempfile\n", @@ -93,7 +93,7 @@ "Line 2\n", "Line 3\n", " \"\"\")\n", - "ds_file = tf.contrib.data.TextLineDataset(filename)\n" + "ds_file = tf.data.TextLineDataset(filename)\n" ] }, { @@ -105,7 +105,7 @@ "source": [ "# Step 2: Apply transformations\n", "\n", - "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.contrib.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset) for details." + "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details." ] }, { diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD index 5759ca17facda2e94a35bcc7e2a54b80ff5ac858..536cad998d94e45187d30fce3be0d7a57178e0c1 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -39,5 +39,6 @@ cuda_py_test( tags = [ "noasan", "nomsan", + "notsan", ], ) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/README.md b/tensorflow/contrib/eager/python/examples/resnet50/README.md index f6c1defa4246d46447028f86c87c4ea9b39bb2ad..db023e6c976c8eda09ef0dee7eecb144678773c4 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/README.md +++ b/tensorflow/contrib/eager/python/examples/resnet50/README.md @@ -11,7 +11,18 @@ Contents: # Benchmarks -Using a synthetic data. +Using a synthetic data, run: + +``` +# Using eager execution +python resnet50_test.py --benchmarks=. + +# Using graph execution +python resnet50_graph_test.py --benchmarks=. +``` + +The above uses the model definition included with the TensorFlow pip +package. To build (and run benchmarks) from source: ``` # Using eager execution diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 736a75332ff6403ea1b21387211df6b8fb6034f3..23317886e712323f4b520000e0fd372734fc53a1 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -73,7 +73,7 @@ class ResNet50GraphTest(tf.test.TestCase): tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() with tf.contrib.summary.always_record_summaries(): - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(): model = resnet50.ResNet50(data_format()) @@ -95,7 +95,7 @@ class ResNet50GraphTest(tf.test.TestCase): sess.run([train_op, tf.contrib.summary.all_summary_ops()], feed_dict={images: np_images, labels: np_labels}) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index d6389f2e385b3637b178d49fc56e8baf913eccaa..d8d8644dde10498e5fd480f92b69656fca1558dd 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -95,7 +95,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device): @@ -103,7 +103,7 @@ class ResNet50Test(tf.test.TestCase): images, labels = random_batch(2) train_one_step(model, images, labels, optimizer) self.assertEqual(320, len(model.variables)) - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].tag, 'loss') diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD index b657d31f35bafd6624ac7e4d6a6f6b2db362649d..f83eb5c476ed9f45d70849a0de6c0f20973682a5 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -11,6 +11,7 @@ py_binary( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/python/eager:context", "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 318962c634e0d050b35da5efc405400380c1b759..40919f2d4cf511eb35fac954719286366aef6c7c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -247,10 +247,10 @@ def main(_): log_dir = os.path.join(FLAGS.dir, "summaries") tf.gfile.MakeDirs(log_dir) - train_summary_writer = tf.contrib.summary.create_summary_file_writer( - os.path.join(log_dir, "train"), flush_secs=10) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( - os.path.join(log_dir, "eval"), flush_secs=10, name="eval") + train_summary_writer = tf.contrib.summary.create_file_writer( + os.path.join(log_dir, "train"), flush_millis=10000) + test_summary_writer = tf.contrib.summary.create_file_writer( + os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") with tf.device(device): for epoch in range(FLAGS.num_epochs): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD index db2587bf2cb548ae37e58597691e96ae2c2e8177..4b4792cd49bf8bd4ad46a0371ef0d2f8a07ddd1c 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -10,7 +10,9 @@ py_binary( srcs_version = "PY2AND3", deps = [ "//tensorflow:tensorflow_py", + "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py", "//tensorflow/contrib/eager/python:tfe", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md index ea92d59e5863226a1bc28a07919518f209587cb5..743ebb68ee5bba5635899267cc4839828f7e4e2f 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md @@ -18,6 +18,18 @@ To run: Benchmarks (using synthetic data): +``` +# Using eager execution +python rnn_ptb_test.py --benchmarks=. + +# Using graph execution +python rnn_ptb_graph_test.py --benchmarks=. +``` + +The above uses the model definition included with the TensorFlow pip +package. To build (and run benchmarks) from source: + + ``` # Using eager execution bazel run -c opt --config=cuda :rnn_ptb_test -- --benchmarks=. diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a1f8a759e2a556bc219f0aa13942f293c4f34cfa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -0,0 +1,42 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "data", + srcs = ["data.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//third_party/py/numpy"], +) + +py_test( + name = "data_test", + size = "small", + srcs = ["data_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":data", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "spinn_test", + size = "medium", + srcs = ["spinn_test.py"], + additional_deps = [ + ":data", + "//third_party/examples/eager/spinn", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/python/eager:test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], + tags = ["no_pip"], # because spinn.py is under third_party/. +) diff --git a/tensorflow/contrib/eager/python/examples/spinn/README.md b/tensorflow/contrib/eager/python/examples/spinn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eb0637df473e22e5d39ca1b0816464cb2b7c6435 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/README.md @@ -0,0 +1,13 @@ +# SPINN: Dynamic neural network with TensorFlow eager execution + +This directory contains files supporting the +[spinn.py model in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/spinn.py), +including + +- `data.py`: Utility library for loading and preprocessing the SNLI and GloVe + data. +- `data_test.py` and `spinn_test.py`: Unit tests for the data and model modules. + +See the [README.md in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/README.md) +for detailed background, license and usage information regarding the SPINN code. + diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py new file mode 100644 index 0000000000000000000000000000000000000000..a6e046320f78541bef4e091e97f08fd51857af83 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -0,0 +1,350 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities of SNLI data and GloVe word vectors for SPINN model. + +See more details about the SNLI data set at: + https://nlp.stanford.edu/projects/snli/ + +See more details about the GloVe pretrained word embeddings at: + https://nlp.stanford.edu/projects/glove/ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import math +import os +import random + +import numpy as np + +POSSIBLE_LABELS = ("entailment", "contradiction", "neutral") + +UNK_CODE = 0 # Code for unknown word tokens. +PAD_CODE = 1 # Code for padding tokens. + +SHIFT_CODE = 3 +REDUCE_CODE = 2 + +WORD_VECTOR_LEN = 300 # Embedding dimensions. + +LEFT_PAREN = "(" +RIGHT_PAREN = ")" +PARENTHESES = (LEFT_PAREN, RIGHT_PAREN) + + +def get_non_parenthesis_words(items): + """Get the non-parenthesis items from a SNLI parsed sentence. + + Args: + items: Data items from a parsed SNLI setence, with parentheses. E.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of non-parenthis word items, all converted to lower case. E.g., + ["man", "wearing", "pass", ... + """ + return [x.lower() for x in items if x not in PARENTHESES and x] + + +def get_shift_reduce(items): + """Obtain shift-reduce vector from a list of items from the SNLI data. + + Args: + items: Data items as a list of str, e.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and + `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE` + and `REDUCE_CODE`. + """ + trans = [] + for item in items: + if item == LEFT_PAREN: + continue + elif item == RIGHT_PAREN: + trans.append(REDUCE_CODE) + else: + trans.append(SHIFT_CODE) + return trans + + +def pad_and_reverse_word_ids(sentences): + """Pad a list of sentences to the common maximum length + 1. + + Args: + sentences: A list of sentences as a list of list of integers. Each integer + is a word ID. Each list of integer corresponds to one sentence. + + Returns: + A numpy.ndarray of shape (num_sentences, max_length + 1), wherein max_length + is the maximum sentence length (in # of words). Each sentence is reversed + and then padded with an extra one at head, as required by the model. + """ + max_len = max(len(sent) for sent in sentences) + for sent in sentences: + if len(sent) < max_len: + sent.extend([PAD_CODE] * (max_len - len(sent))) + # Reverse in time order and pad an extra one. + sentences = np.fliplr(np.array(sentences, dtype=np.int64)) + sentences = np.concatenate( + [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1) + return sentences + + +def pad_transitions(sentences_transitions): + """Pad a list of shift-reduce transitions to the maximum length.""" + max_len = max(len(transitions) for transitions in sentences_transitions) + for transitions in sentences_transitions: + if len(transitions) < max_len: + transitions.extend([PAD_CODE] * (max_len - len(transitions))) + return np.array(sentences_transitions, dtype=np.int64) + + +def load_vocabulary(data_root): + """Load vocabulary from SNLI data files. + + Args: + data_root: Root directory of the data. It is assumed that the SNLI data + files have been downloaded and extracted to the "snli/snli_1.0" + subdirectory of it. + + Returns: + Vocabulary as a set of strings. + + Raises: + ValueError: If SNLI data files cannot be found. + """ + snli_path = os.path.join(data_root, "snli") + snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt") + file_names = glob.glob(snli_glob_pattern) + if not file_names: + raise ValueError( + "Cannot find SNLI data files at %s. " + "Please download and extract SNLI data first." % snli_glob_pattern) + + print("Loading vocabulary...") + vocab = set() + for file_name in file_names: + with open(os.path.join(snli_path, file_name), "rt") as f: + for i, line in enumerate(f): + if i == 0: + continue + items = line.split("\t") + premise_words = get_non_parenthesis_words(items[1].split(" ")) + hypothesis_words = get_non_parenthesis_words(items[2].split(" ")) + vocab.update(premise_words) + vocab.update(hypothesis_words) + return vocab + + +def load_word_vectors(data_root, vocab): + """Load GloVe word vectors for words present in the vocabulary. + + Args: + data_root: Data root directory. It is assumed that the GloVe file + has been downloaded and extracted at the "glove/" subdirectory of it. + vocab: A `set` of words, representing the vocabulary. + + Returns: + 1. word2index: A dict from lower-case word to row index in the embedding + matrix, i.e, `embed` below. + 2. embed: The embedding matrix as a float32 numpy array. Its shape is + [vocabulary_size, WORD_VECTOR_LEN]. vocabulary_size is len(vocab). + WORD_VECTOR_LEN is the embedding dimension (300). + + Raises: + ValueError: If GloVe embedding file cannot be found. + """ + glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt") + if not os.path.isfile(glove_path): + raise ValueError( + "Cannot find GloVe embedding file at %s. " + "Please download and extract GloVe embeddings first." % glove_path) + + print("Loading word vectors...") + + word2index = dict() + embed = [] + + embed.append([0] * WORD_VECTOR_LEN) # + embed.append([0] * WORD_VECTOR_LEN) # + word2index[""] = UNK_CODE + word2index[""] = PAD_CODE + + with open(glove_path, "rt") as f: + for line in f: + items = line.split(" ") + word = items[0] + if word in vocab and word not in word2index: + word2index[word] = len(embed) + vector = np.array([float(item) for item in items[1:]]) + assert (WORD_VECTOR_LEN,) == vector.shape + embed.append(vector) + embed = np.array(embed, dtype=np.float32) + return word2index, embed + + +def calculate_bins(length2count, min_bin_size): + """Cacluate bin boundaries given a histogram of lengths and mininum bin size. + + Args: + length2count: A `dict` mapping length to sentence count. + min_bin_size: Minimum bin size in terms of total number of sentence pairs + in the bin. + + Returns: + A `list` representing the right bin boundaries, starting from the inclusive + right boundary of the first bin. For example, if the output is + [10, 20, 35], + it means there are three bins: [1, 10], [11, 20] and [21, 35]. + """ + bounds = [] + lengths = sorted(length2count.keys()) + cum_count = 0 + for length in lengths: + cum_count += length2count[length] + if cum_count >= min_bin_size: + bounds.append(length) + cum_count = 0 + if bounds[-1] != lengths[-1]: + bounds.append(lengths[-1]) + return bounds + + +class SnliData(object): + """A split of SNLI data.""" + + def __init__(self, data_file, word2index, sentence_len_limit=-1): + """SnliData constructor. + + Args: + data_file: Full path to the data file, e.g., + "/tmp/spinn-data/snli/snli_1.0/snli_1.0.train.txt" + word2index: A dict from lower-case word to row index in the embedding + matrix (see `load_word_vectors()` for details). + sentence_len_limit: Maximum allowed sentence length (# of words). + A value of <= 0 means unlimited. Sentences longer than this limit + are currently discarded, not truncated. + """ + + self._labels = [] + self._premises = [] + self._premise_transitions = [] + self._hypotheses = [] + self._hypothesis_transitions = [] + + with open(data_file, "rt") as f: + for i, line in enumerate(f): + if i == 0: + # Skip header line. + continue + items = line.split("\t") + if items[0] not in POSSIBLE_LABELS: + continue + + premise_items = items[1].split(" ") + hypothesis_items = items[2].split(" ") + premise_words = get_non_parenthesis_words(premise_items) + hypothesis_words = get_non_parenthesis_words(hypothesis_items) + + if (sentence_len_limit > 0 and + (len(premise_words) > sentence_len_limit or + len(hypothesis_words) > sentence_len_limit)): + # TODO(cais): Maybe truncate; do not discard. + continue + + premise_ids = [ + word2index.get(word, UNK_CODE) for word in premise_words] + hypothesis_ids = [ + word2index.get(word, UNK_CODE) for word in hypothesis_words] + + self._premises.append(premise_ids) + self._hypotheses.append(hypothesis_ids) + self._premise_transitions.append(get_shift_reduce(premise_items)) + self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items)) + assert (len(self._premise_transitions[-1]) == + 2 * len(premise_words) - 1) + assert (len(self._hypothesis_transitions[-1]) == + 2 * len(hypothesis_words) - 1) + + self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1) + + assert len(self._labels) == len(self._premises) + assert len(self._labels) == len(self._hypotheses) + assert len(self._labels) == len(self._premise_transitions) + assert len(self._labels) == len(self._hypothesis_transitions) + + def num_batches(self, batch_size): + """Calculate number of batches given batch size.""" + return int(math.ceil(len(self._labels) / batch_size)) + + def get_generator(self, batch_size): + """Obtain a generator for batched data. + + All examples of this SnliData object are randomly shuffled, sorted + according to the maximum sentence length of the premise and hypothesis + sentences in the pair, and batched. + + Args: + batch_size: Desired batch size. + + Returns: + A generator for data batches. The generator yields a 5-tuple: + label: An array of the shape (batch_size,). + premise: An array of the shape (max_premise_len, batch_size), wherein + max_premise_len is the maximum length of the (padded) premise + sentence in the batch. + premise_transitions: An array of the shape (2 * max_premise_len -3, + batch_size). + hypothesis: Same as `premise`, but for hypothesis sentences. + hypothesis_transitions: Same as `premise_transitions`, but for + hypothesis sentences. + All the elements of the 5-tuple have dtype `int64`. + """ + # Randomly shuffle examples. + zipped = list(zip( + self._labels, self._premises, self._premise_transitions, + self._hypotheses, self._hypothesis_transitions)) + random.shuffle(zipped) + # Then sort the examples by maximum of the premise and hypothesis sentence + # lengths in the pair. During training, the batches are expected to be + # shuffled. So it is okay to leave them sorted by max length here. + (labels, premises, premise_transitions, hypotheses, + hypothesis_transitions) = zip( + *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3])))) + + def _generator(): + begin = 0 + while begin < len(labels): + # The sorting above and the batching here makes sure that sentences of + # similar max lengths are batched together, minimizing the inefficiency + # due to uneven max lengths. The sentences are batched differently in + # each call to get_generator() due to the shuffling before sotring + # above. The pad_and_reverse_word_ids() and pad_transitions() functions + # take care of any remaning unevenness of the max sentence lengths. + end = min(begin + batch_size, len(labels)) + # Transpose, because the SPINN model requires time-major, instead of + # batch-major. + yield (labels[begin:end], + pad_and_reverse_word_ids(premises[begin:end]).T, + pad_transitions(premise_transitions[begin:end]).T, + pad_and_reverse_word_ids(hypotheses[begin:end]).T, + pad_transitions(hypothesis_transitions[begin:end]).T) + begin = end + return _generator diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f0b37c5099e45b7e3b258b258c0a203c36b3b7 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -0,0 +1,243 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for SPINN data module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.spinn import data + + +class DataTest(tf.test.TestCase): + + def setUp(self): + super(DataTest, self).setUp() + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(DataTest, self).tearDown() + + def testGenNonParenthesisWords(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") + self.assertEqual( + ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing", + "in", "a", "crowd", "of", "people", "."], + data.get_non_parenthesis_words(seq_with_parse.split(" "))) + + def testGetShiftReduce(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") + self.assertEqual( + [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2, + 3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" "))) + + def testPadAndReverseWordIds(self): + id_sequences = [[0, 2, 3, 4, 5], + [6, 7, 8], + [9, 10, 11, 12, 13, 14, 15, 16]] + self.assertAllClose( + [[1, 1, 1, 1, 5, 4, 3, 2, 0], + [1, 1, 1, 1, 1, 1, 8, 7, 6], + [1, 16, 15, 14, 13, 12, 11, 10, 9]], + data.pad_and_reverse_word_ids(id_sequences)) + + def testPadTransitions(self): + unpadded = [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2]] + self.assertAllClose( + [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2, 1, 1]], + data.pad_transitions(unpadded)) + + def testCalculateBins(self): + length2count = { + 1: 10, + 2: 15, + 3: 25, + 4: 40, + 5: 35, + 6: 10} + self.assertEqual([2, 3, 4, 5, 6], + data.calculate_bins(length2count, 20)) + self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40)) + self.assertEqual([4, 6], data.calculate_bins(length2count, 60)) + + def testLoadVoacbulary(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt") + os.makedirs(snli_1_0_dir) + + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + with open(fake_dev_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Quux quuz ) ? )\t( ( Corge grault ) ! )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Quux quuz?\t.Corge grault!\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + self.assertSetEqual( + {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"}, + vocab) + + def testLoadVoacbularyWithoutFileRaisesError(self): + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + def testLoadWordVectors(self): + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", ",", "foo", "bar", "baz"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = {"foo", "bar", "baz", "qux", "."} + # Notice that "qux" is not present in `words`. + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + self.assertEqual(6, len(word2index)) + self.assertEqual(0, word2index[""]) + self.assertEqual(1, word2index[""]) + self.assertEqual(2, word2index["."]) + self.assertEqual(3, word2index["foo"]) + self.assertEqual(4, word2index["bar"]) + self.assertEqual(5, word2index["baz"]) + self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :]) + self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :]) + self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :]) + self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :]) + + def testLoadWordVectorsWithoutFileRaisesError(self): + vocab = {"foo", "bar", "baz", "qux", "."} + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + os.makedirs(os.path.join(self._temp_data_dir, "glove")) + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + def testSnliData(self): + """Unit test for SnliData objects.""" + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + self.assertEqual(4, train_data.num_batches(1)) + self.assertEqual(2, train_data.num_batches(2)) + self.assertEqual(2, train_data.num_batches(3)) + self.assertEqual(1, train_data.num_batches(4)) + + generator = train_data.get_generator(2)() + for i in range(2): + label, prem, prem_trans, hypo, hypo_trans = next(generator) + self.assertEqual(2, len(label)) + self.assertEqual((4, 2), prem.shape) + self.assertEqual((5, 2), prem_trans.shape) + self.assertEqual((3, 2), hypo.shape) + self.assertEqual((3, 2), hypo_trans.shape) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..84e25cf81a2223800c47994b26d000caddee6b01 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -0,0 +1,409 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import gc +import glob +import os +import shutil +import tempfile +import time + +import numpy as np +import tensorflow as tf + +# pylint: disable=g-bad-import-order +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.spinn import data +from third_party.examples.eager.spinn import spinn +from tensorflow.contrib.summary import summary_test_util +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +# pylint: enable=g-bad-import-order + + +def _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size): + """Generate a fake batch of SNLI data for testing.""" + with tf.device("cpu:0"): + labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64) + prem = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + prem_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + hypo = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + hypo_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + if tfe.num_gpus(): + labels = labels.gpu() + prem = prem.gpu() + prem_trans = prem_trans.gpu() + hypo = hypo.gpu() + hypo_trans = hypo_trans.gpu() + return labels, prem, prem_trans, hypo, hypo_trans + + +def _test_spinn_config(d_embed, d_out, logdir=None): + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict", + "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", + "d_out", "projection", "lr", "batch_size", "epochs", + "force_cpu", "logdir", "log_every", "dev_every", "save_every", + "lr_decay_every", "lr_decay_by"]) + return config_tuple( + d_hidden=d_embed, + d_proj=d_embed * 2, + d_tracker=8, + predict=False, + embed_dropout=0.1, + mlp_dropout=0.1, + n_mlp_layers=2, + d_mlp=32, + d_out=d_out, + projection=True, + lr=2e-2, + batch_size=2, + epochs=10, + force_cpu=False, + logdir=logdir, + log_every=1, + dev_every=2, + save_every=2, + lr_decay_every=1, + lr_decay_by=0.75) + + +class SpinnTest(test_util.TensorFlowTestCase): + + def setUp(self): + super(SpinnTest, self).setUp() + self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(SpinnTest, self).tearDown() + + def testBundle(self): + with tf.device(self._test_device): + lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32), + np.array([[0, -1], [-2, -3]], dtype=np.float32), + np.array([[0, 2], [4, 6]], dtype=np.float32), + np.array([[0, -2], [-4, -6]], dtype=np.float32)] + out = spinn._bundle(lstm_iter) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T, + out[0].numpy()) + self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T, + out[1].numpy()) + + def testUnbunbdle(self): + with tf.device(self._test_device): + state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32), + np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)] + out = spinn._unbundle(state) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]), + out[0].numpy()) + self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]), + out[1].numpy()) + + def testReducer(self): + with tf.device(self._test_device): + batch_size = 3 + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + left_in = [] + right_in = [] + tracking = [] + for _ in range(batch_size): + left_in.append(tf.random_normal((1, size * 2))) + right_in.append(tf.random_normal((1, size * 2))) + tracking.append(tf.random_normal((1, tracker_size * 2))) + + out = reducer(left_in, right_in, tracking=tracking) + self.assertEqual(batch_size, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual((1, size * 2), out[0].shape) + + def testReduceTreeLSTM(self): + with tf.device(self._test_device): + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + lstm_in = np.array([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]], + dtype=np.float32) + c1 = np.array([[0, 1], [2, 3]], dtype=np.float32) + c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32) + + h, c = reducer._tree_lstm(c1, c2, lstm_in) + self.assertEqual(tf.float32, h.dtype) + self.assertEqual(tf.float32, c.dtype) + self.assertEqual((2, 2), h.shape) + self.assertEqual((2, 2), c.shape) + + def testTracker(self): + with tf.device(self._test_device): + batch_size = 2 + size = 10 + tracker_size = 8 + buffer_length = 18 + stack_size = 3 + + tracker = spinn.Tracker(tracker_size, False) + tracker.reset_state() + + # Create dummy inputs for testing. + bufs = [] + buf = [] + for _ in range(buffer_length): + buf.append(tf.random_normal((batch_size, size * 2))) + bufs.append(buf) + self.assertEqual(1, len(bufs)) + self.assertEqual(buffer_length, len(bufs[0])) + self.assertEqual((batch_size, size * 2), bufs[0][0].shape) + + stacks = [] + stack = [] + for _ in range(stack_size): + stack.append(tf.random_normal((batch_size, size * 2))) + stacks.append(stack) + self.assertEqual(1, len(stacks)) + self.assertEqual(3, len(stacks[0])) + self.assertEqual((batch_size, size * 2), stacks[0][0].shape) + + for _ in range(2): + out1, out2 = tracker(bufs, stacks) + self.assertIsNone(out2) + self.assertEqual(batch_size, len(out1)) + self.assertEqual(tf.float32, out1[0].dtype) + self.assertEqual((1, tracker_size * 2), out1[0].shape) + + self.assertEqual(tf.float32, tracker.state.c.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.c.shape) + self.assertEqual(tf.float32, tracker.state.h.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.h.shape) + + def testSPINN(self): + with tf.device(self._test_device): + embedding_dims = 10 + d_tracker = 8 + sequence_length = 15 + num_transitions = 27 + + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict"]) + config = config_tuple( + embedding_dims, embedding_dims * 2, d_tracker, False) + s = spinn.SPINN(config) + + # Create some fake data. + buffers = tf.random_normal((sequence_length, 1, config.d_proj)) + transitions = tf.constant( + [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3], + [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2], + [3], [2], [2]], dtype=tf.int64) + self.assertEqual(tf.int64, transitions.dtype) + self.assertEqual((num_transitions, 1), transitions.shape) + + out = s(buffers, transitions, training=True) + self.assertEqual(tf.float32, out.dtype) + self.assertEqual((1, embedding_dims), out.shape) + + def testSNLIClassifierAndTrainer(self): + with tf.device(self._test_device): + vocab_size = 40 + batch_size = 2 + d_embed = 10 + sequence_length = 15 + d_out = 4 + + config = _test_spinn_config(d_embed, d_out) + + # Create fake embedding matrix. + embed = tf.random_normal((vocab_size, d_embed)) + + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + # Invoke model under non-training mode. + logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Invoke model under training model. + logits = model(prem, prem_trans, hypo, hypo_trans, training=True) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Calculate loss. + loss1 = trainer.loss(labels, logits) + self.assertEqual(tf.float32, loss1.dtype) + self.assertEqual((), loss1.shape) + + loss2, logits = trainer.train_batch( + labels, prem, prem_trans, hypo, hypo_trans) + self.assertEqual(tf.float32, loss2.dtype) + self.assertEqual((), loss2.shape) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + # Training on the batch should have led to a change in the loss value. + self.assertNotEqual(loss1.numpy(), loss2.numpy()) + + def testTrainSpinn(self): + """Test with fake toy SNLI data and GloVe vectors.""" + + # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + dev_data = data.SnliData(fake_train_file, word2index) + test_data = data.SnliData(fake_train_file, word2index) + print(embed) + + # 2. Create a fake config. + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir")) + + # 3. Test training of a SPINN model. + spinn.train_spinn(embed, train_data, dev_data, test_data, config) + + # 4. Load train loss values from the summary files and verify that they + # decrease with training. + summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0] + events = summary_test_util.events_from_file(summary_file) + train_losses = [event.summary.value[0].simple_value for event in events + if event.summary.value + and event.summary.value[0].tag == "train/loss"] + self.assertEqual(config.epochs, len(train_losses)) + self.assertLess(train_losses[-1], train_losses[0]) + + +class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): + + def benchmarkEagerSpinnSNLIClassifier(self): + test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + with tf.device(test_device): + burn_in_iterations = 2 + benchmark_iterations = 10 + + vocab_size = 1000 + batch_size = 128 + sequence_length = 15 + d_embed = 200 + d_out = 4 + + embed = tf.random_normal((vocab_size, d_embed)) + + config = _test_spinn_config(d_embed, d_out) + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + for _ in range(burn_in_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + + gc.collect() + start_time = time.time() + for _ in xrange(benchmark_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + wall_time = time.time() - start_time + # Named "examples"_per_sec to conform with other benchmarks. + extras = {"examples_per_sec": benchmark_iterations / wall_time} + self.report_benchmark( + name="Eager_SPINN_SNLIClassifier_Benchmark", + iters=benchmark_iterations, + wall_time=wall_time, + extras=extras) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index e76745a807cb10adf2aedc56e69cea0ceded3ad7..0095ffa0db99d46d25654d73504d0d7d41c18b6f 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -388,7 +388,7 @@ many arguments. In fact, eager execution encourages use of the [Keras](https://keras.io)-style "Layer" classes in the -[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers) +[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) module. Furthermore, you may want to apply more sophisticated techniques to compute @@ -488,10 +488,10 @@ parameters of the model as arguments to the `loss` function. ### Using Keras and the Layers API [Keras](https://keras.io) is a popular API for defining model structures. The -[`tf.keras.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/keras/layers) +[`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) module provides a set of building blocks for models and is implemented using the `tf.layers.Layer` subclasses in the -[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers) +[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) module. We encourage the use of these same building blocks when using TensorFlow's eager execution feature. For example, the very same linear regression model can be built using `tf.layers.Dense`: @@ -608,9 +608,9 @@ it provides conveniences like keeping track of all model variables and methods to save and restore from checkpoints. Sub-classes of `tfe.Network` may register `Layer`s (like classes in -[`tf.layers`](https://www.tensorflow.org/versions/master/api_docs/python/tf/layers), +[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), or [Keras -layers](https://www.tensorflow.org/versions/master/api_docs/python/tf/keras/layers)) +layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) using a call to `self.track_layer()` and define the computation in an implementation of `call()`. @@ -704,7 +704,7 @@ with tfe.restore_variables_on_create( net(inp).numpy())) all_variables = ( net.variables - + tfe.get_optimizer_variables(optimizer) + + optimizer.variables() + [global_step]) # Save the checkpoint. tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step) @@ -757,7 +757,7 @@ For example, to record summaries once every 100 global steps, use: ```python tf.train.get_or_create_global_step() # Ensuring the global step variable exists -writer = tf.contrib.summary.create_summary_file_writer(logdir) +writer = tf.contrib.summary.create_file_writer(logdir) for _ in range(iterations): with writer.as_default(): @@ -800,7 +800,7 @@ example in The discussion above has been centered around the computation executed by your model. The -[`tf.data`](https://www.tensorflow.org/versions/master/api_docs/python/tf/data) +[`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data) module provides APIs to build complex input pipelines from simple, reusable pieces. @@ -810,8 +810,7 @@ However, the process of iterating over elements of the dataset differs between eager execution and graph construction. When eager execution is enabled, the discussion on iterator creation using `make_one_shot_iterator()` and `get_next()` in the -[Programmer's -Guide](https://www.tensorflow.org/versions/master/programmers_guide/datasets) is +[Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is *not* applicable. Instead, a more Pythonic `Iterator` class is available. For example: diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 2ba653af4a2465a17a17ff4ff019e69476f6434e..2f8016ede3caee6dbb6fd8f5226f1464b5c3976b 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -73,7 +73,7 @@ class Metric(object): * `result()`: Computes and returns a final value for the metric from the variables in `self`. - Decendants may override `aggregate()`, but usually won't need to. It + Descendants may override `aggregate()`, but usually won't need to. It adds in the state from a list of metrics of the same type as `self`. (Default is to sum all the variables.) Note that users should not call `aggregate()`, it is for use by TensorFlow infrastructure. @@ -223,8 +223,17 @@ class Metric(object): """***Only for use by descendants of Metric***.""" if self._built: raise RuntimeError("Can't call add_variable() except in build().") - v = variable_scope.get_variable(name, shape, dtype, initializer, - trainable=False, use_resource=True) + collections = None if context.in_eager_mode() else [ + ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES + ] + v = variable_scope.get_variable( + name, + shape, + dtype, + initializer, + trainable=False, + collections=collections, + use_resource=True) self._vars.append(v) if context.in_eager_mode(): self._initial_values[v] = v.value() diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index b945e97a0049441d356f41e4d19fe6f01836ec40..1055f4563cd4608189281450aed512fbf5f31de1 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.training import training_util @@ -41,6 +42,17 @@ class MetricsTest(test.TestCase): self.assertEqual(dtypes.float64, m.dtype) self.assertEqual(dtypes.float64, m.result().dtype) + def testVariableCollections(self): + with context.graph_mode(), ops.Graph().as_default(): + m = metrics.Mean() + m(1000) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) + def testInitVariables(self): m = metrics.Mean() m([1, 10, 100, 1000]) @@ -55,12 +67,12 @@ class MetricsTest(test.TestCase): m([1, 10, 100]) training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( logdir, max_queue=0, name="t0").as_default(), summary_ops.always_record_summaries(): m.result() # As a side-effect will write summaries. - events = summary_test_util.events_from_file(logdir) + events = summary_test_util.events_from_logdir(logdir) self.assertEqual(len(events), 2) self.assertEqual(events[1].summary.value[0].simple_value, 37.0) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 5b53a597f20a1cd0ba9be7f1d3a89e117cde66e8..e3c13cbd2e8ccd2ab79da74e0e97905c6ed5c02d 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -37,164 +37,98 @@ from tensorflow.python.training import training_util # functions in base.py which should be reused. -_DeferredRestoration = collections.namedtuple( +def _network_name_scope_naming(current_variable_scope): + """Name scope naming to match operation names to variable names. - "_DeferredRestoration", - [ - # The map_func to use (either user-specified or the default). - "map_func", - # Boolean, True if the user specified an explicit map_func, for error - # messages. - "map_func_is_user", - # A mapping from checkpoint names to initial values of not-yet-created - # variables which should be restored. These values come from parsing a - # checkpoint. - "checkpointed_variables_to_restore", - # A mapping from checkpoint name to variable objects of variables which - # have already been restored, for error checking. - "restored_variables", - # The session to restore with (if in graph mode). - "session", - # Names of the Network where the restore was requested, for error - # messages. - "network_name", - "network_scope_name" - ]) + Used in Networks and also applied to non-Network Layers which are added to + Networks before being built. - -def _default_naming_conflict_error_message( - mapped_name, first_variable, second_variable, - network_name, network_scope_name): - return ( - ("The default checkpoint variable name mapping strategy for Network " - "'%s' resulted in a naming conflict. We attempted to strip off the " - "variable prefix for the Network ('%s'), but this resulted in two " - "variables named '%s' (originally '%s' and '%s'). This should only " - "happen when using variable sharing (i.e. the Network contains Networks " - "or Layers which were first added to another Network, and therefore " - "have that Network's variable prefix). One solution is to pass " - "`map_func=lambda n: n` to Network.save and Network.restore to use " - "fully qualified variable names in the checkpoint, although this will " - "require that the variable prefix of the Network being restored into " - "is also '%s'. You may alternatively write an arbitrary mapping.") - % ( - network_name, network_scope_name, mapped_name, - first_variable._shared_name, - second_variable._shared_name, network_scope_name - )) - - -def _restore_custom_map_func_error_message( - mapped_name, first_variable, second_variable, - network_name, network_scope_name): - return ( - ("The map_func passed to Network.restore for the Network '%s' " - "resulted in two variables named '%s' (originally '%s' and '%s'). Since " - "this is also an error on Network.save, this Network was " - "probably not saved with this map_func. Note that map_func " - "always maps from full variable names to checkpoint names; " - "there is no need to specify an inverse mapping.\n\n" - "Try stripping less from the variable names, or renaming parts " - "of the Network. For reference, variables created by sub-Layers " - "of this Network are prefixed with '%s', but if they are " - "re-used after being added to another Network they will have " - "that Network's full variable prefix instead.") % ( - network_name, mapped_name, - first_variable._shared_name, - second_variable._shared_name, - network_scope_name)) - - -def _make_custom_getter_for_deferred_restorations(): - """Returns a custom getter which searches `deferred_restorations`. - - Returns: A tuple of (_custom_getter, deferred_restorations) - _custom_getter: The getter which should be added to variable_scopes where - variables will be created. - deferred_restorations: A list for _DeferredRestoration objects. Typically - empty when the getter is set, and expanded as deferred restorations are - requested. All new deferred restorations should be appended to the end of - the list, where they will have priority over older deferred restorations. + Args: + current_variable_scope: A VariableScope object. + Returns: + A name scope name. """ - deferred_restorations = [] - - def _custom_getter(getter, name, shape=None, dtype=None, - initializer=None, - *args, **kwargs): - """A custom getter which processes deferred restorations.""" - # Iterate over restorations, newest first (newer restorations will take - # precedence over older restorations, just like with immediate restorations - # into existing variables). - delayed_restoration = None - found_value = False - value_to_restore = None - for delayed_restoration in reversed( - deferred_restorations): - checkpoint_name = delayed_restoration.map_func(name) - if (checkpoint_name - in delayed_restoration.checkpointed_variables_to_restore): - found_value = True - value_to_restore = ( - delayed_restoration.checkpointed_variables_to_restore[ - checkpoint_name]) - if found_value: - break - # value_to_restore may be False because this variable is not in any - # checkpoint we are restoring, or None because we have explicitly set it to - # None when it was previously fetched. In either case, we don't need to - # set an initializer. - if found_value and value_to_restore is not None: - initializer = value_to_restore - shape = None - variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, - *args, **kwargs) - if found_value and value_to_restore is not None: - # Mark as already restored from this checkpoint. - delayed_restoration.checkpointed_variables_to_restore[ - checkpoint_name] = None - if context.in_graph_mode(): - delayed_restoration.session.run(variable.initializer) - if found_value: - # Error checking should run even if we've already restored a value. - if delayed_restoration.restored_variables.setdefault( - checkpoint_name, variable) is not variable: - # Naming conflict. We've tried to initialize two variables with the - # same value from the checkpoint. - if delayed_restoration.map_func_is_user: - raise ValueError( - _restore_custom_map_func_error_message( - mapped_name=checkpoint_name, - first_variable=delayed_restoration.restored_variables[ - checkpoint_name], - second_variable=variable, - network_name=delayed_restoration.network_name, - network_scope_name=delayed_restoration.network_scope_name)) - else: - raise ValueError( - _default_naming_conflict_error_message( - mapped_name=checkpoint_name, - first_variable=delayed_restoration.restored_variables[ - checkpoint_name], - second_variable=variable, - network_name=delayed_restoration.network_name, - network_scope_name=delayed_restoration.network_scope_name)) - return variable - return _custom_getter, deferred_restorations + return current_variable_scope.name + "/" class Network(base.Layer): """Represents the composition of a set of Layers. - TODO(josh11b,ashankar): - - Should "trainable" be changeable on the Network object? - - Do we allow add_variable in Network? - - Detect layers used in __call__ that weren't registered with track_layer. - - Convert inputs to __call__ to tensors. - - Prevent variables from being created after the first __call__? - (Think about restoring from a checkpoint). + `Network` implements the `Layer` interface and adds convenience methods for + managing sub-`Layer`s, such as listing variables. + + `Layer`s (including other `Network`s) should be added via `track_layer`. They + can then be used when overriding the `Network.call` method: + + ```python + class TwoLayerNetwork(tfe.Network): + + def __init__(self, name): + super(TwoLayerNetwork, self).__init__(name=name) + self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,))) + self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,))) + + def call(self, inputs): + return self.layer_two(self.layer_one(inputs)) + ``` + + After constructing an object and calling the `Network`, a list of variables + created by tracked `Layer`s is available via `Network.variables`: + + ```python + net = TwoLayerNetwork(name="net") + output = net(tf.ones([1, 8])) + print([v.name for v in net.variables]) + ``` + + This example prints variable names, one kernel and one bias per + `tf.layers.Dense` layer: + + ``` + ['net/dense/kernel:0', + 'net/dense/bias:0', + 'net/dense_1/kernel:0', + 'net/dense_1/bias:0'] + ``` + + These variables can be passed to a `Saver` (`tf.train.Saver`, or + `tf.contrib.eager.Saver` when executing eagerly) to save or restore the + `Network`, typically alongside a global step and `tf.train.Optimizer` + variables when checkpointing during training. + + Note that the semantics of calling a `Network` with graph execution (i.e. not + executing eagerly) may change slightly in the future. Currently stateful ops + are pruned from the graph unless they or something that depends on them is + executed in a session, but this behavior is not consistent with eager + execution (where stateful ops are executed eagerly). `Layer`s from `tf.layers` + do not depend on this pruning and so will not be affected, but `Network`s + which rely on stateful ops being added to the graph but not executed (e.g. via + custom `Layer`s which manage stateful ops) may break with this change. """ + # TODO(josh11b,ashankar,allenl): + # - Should 'trainable' be changeable on the Network object? + # - Do we allow add_variable in Network? + # - Detect layers used in __call__ that weren't registered with track_layer. + # - Convert inputs to __call__ to tensors. def __init__(self, name=None): + """Configure the `Network`. + + Args: + name: The name to use for this `Network`. If specified, it must be unique + in the context where this `Network` is first + (1) added to another `Network` (in which case it must not share a name + with other `Layers` added to that `Network`), or + (2) built/called (in which case no other 'top-level' `Network`s may + share this name). + If unspecified or None, the `Network` will be named using its class + name, with a number appended if necessary for uniqueness (e.g. MyNetwork + -> 'my_network_1'). + + Raises: + ValueError: If `name` is not valid. Note that some naming errors will + instead be raised when the `Network` is called. + """ if isinstance(name, variable_scope.VariableScope): raise ValueError("VariableScopes are not valid Network names.") if name is not None and "/" in name: @@ -210,8 +144,17 @@ class Network(base.Layer): self._owned_layers = {} # The scope to use if we end up without a parent. self._default_parent_variable_scope = variable_scope.get_variable_scope() - self._custom_getter, self._deferred_restorations = ( - _make_custom_getter_for_deferred_restorations()) + # Hold on to the variable scope counts from init to check whether a scope + # with the name we want was ever created in our parent scope. Without this + # check we might have name collisions if the parent scope on init gets + # closed before build is called. + self._variable_scope_counts_on_init = ( + variable_scope._get_default_variable_store().variable_scopes_count) + + def _name_scope_name(self, current_variable_scope): + """Overrides Layer op naming to match variable naming.""" + return _network_name_scope_naming( + current_variable_scope=current_variable_scope) def _init_set_name(self, name): # Anonymous Networks (name=None) defer setting a final name until they are @@ -227,18 +170,30 @@ class Network(base.Layer): def _finalize_name(self, parent_network): if not self._name: - if not parent_network: - name_uid_map = base._get_default_graph_uid_map() - else: - name_uid_map = parent_network._sub_layer_name_uids # Were were not passed a name explicitly (or it was blank), so this is an # anonymous Network. We make up a unique name. if parent_network: avoid_names = parent_network._owned_layers + name_uid_map = parent_network._sub_layer_name_uids else: - avoid_names = None + name_uid_map = base._get_default_graph_uid_map() + # Figure out which names we have to avoid based on which variable scope + # we're nested in. + strip_name = self._default_parent_variable_scope.name + if strip_name: + strip_name += "/" + def _strip_on_init_scope(name): + if name.startswith(strip_name): + return name[len(strip_name):] + else: + return None + avoid_names = set( + _strip_on_init_scope(name) + for name in self._variable_scope_counts_on_init.keys() if name) self._name, self._base_name = self._make_unique_name( - name_uid_map=name_uid_map, avoid_names=avoid_names) + name_uid_map=name_uid_map, avoid_names=avoid_names, + namespace=self._default_parent_variable_scope.name, + zero_based=True) if self._first_parent is None or (self._first_parent # False = no parent and self._first_parent() is None): # Save a pointer to the parent Network so that we can later check that the @@ -268,7 +223,13 @@ class Network(base.Layer): parent_scope = first_parent._scope else: parent_scope = self._default_parent_variable_scope - with variable_scope.variable_scope(parent_scope): + with variable_scope.variable_scope(parent_scope) as parent_vs: + expected_scope_name = parent_vs.name + "/" + self._name + if expected_scope_name in self._variable_scope_counts_on_init: + raise ValueError( + ("A Network named '%s' already exists (or a variable_scope was " + "created with this name). Names must be unique.") % ( + self._name,)) # Make sure variables with this prefix will be unique. with variable_scope.variable_scope( None, use_resource=True, default_name=self._name) as scope: @@ -285,25 +246,22 @@ class Network(base.Layer): "created with this name). Names must be unique.") % ( self._name,)) if (first_parent - and scope_prefix[:-1] != first_parent._scope.name): + and scope_prefix[:-1] != first_parent.scope_name): raise ValueError( ("Network variable names must match a nesting of sub-Network " "names. Expected prefix '%s' from parent network, but got " "'%s' when attempting to create a variable_scope for Network " "'%s'. Likely an explicit variable_scope was inserted into " "the nesting.") % ( - first_parent._scope.name, + first_parent.scope_name, scope_prefix[:-1], self._name)) elif not first_parent and scope_prefix: # For the case when this Network is not nested inside any other - # Network, but is in a variable_scope. This is an error for now. - raise ValueError( - "Creating Networks inside named variable_scopes is currently " - "not supported (to ensure that variable names match the names " - "of Networks in which they were first created). To set " - "options, try `with tf.variable_scope(''):`. If this " - "limitation bothers you, please file a feature request.") + # Network, but is in a variable_scope. This Network's name takes on + # the full variable scope prefix. + self._name = scope_name + for non_network_sublayer in self._non_network_sublayers: self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) @@ -321,7 +279,8 @@ class Network(base.Layer): raise ValueError( ("The parent of a Layer added to Network %s was garbage collected " "before the Layer was built. If this limitation bothers you " - "please, file a feature request.") % (self.name,)) + "please file a feature request.") % + (self.name,)) with variable_scope.variable_scope(parent_scope): # Horrid hack to make Layer variable names which are direct # sub-layers of Networks conform to the Network variable naming @@ -330,6 +289,9 @@ class Network(base.Layer): None, use_resource=True, default_name=sublayer.name) as sub_scope: sublayer._scope = sub_scope + # Also switch op naming for this Layer to match Network conventions, + # i.e. op naming matching variable naming. + sublayer._name_scope_name = _network_name_scope_naming @base.Layer.name.getter def name(self): @@ -384,7 +346,10 @@ class Network(base.Layer): # name, and we should respect it (subject to error checking). layer._name, layer._base_name = layer._make_unique_name( name_uid_map=self._sub_layer_name_uids, - avoid_names=self._owned_layers) + avoid_names=self._owned_layers, + zero_based=True + # No namespace required, since we've specified our own UID map. + ) layer._first_parent = weakref.ref(self) self._non_network_sublayers.append(layer) if (not layer.built @@ -486,272 +451,30 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") - def _strip_variable_prefix(self, original_variable_name): - """The default map_func for saving or restoring variables. - - Strips the variable prefix for the Network on which save/restore was called, - and leaves other variable names fully qualified in the checkpoint. - - Args: - original_variable_name: The _shared_name of the variable (no :0 - suffix) to map. - Returns: - The checkpoint name of the variable. - """ - scope_name_with_slash = self.scope_name + "/" - if original_variable_name.startswith(scope_name_with_slash): - return original_variable_name[len(scope_name_with_slash):] - else: - return original_variable_name - - def save(self, save_path, global_step=None, map_func=None): - """Save variables from the Network to a checkpoint. + def add_loss(self, losses, inputs=None): + raise RuntimeError( + "add_loss is not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") - Args: - save_path: Either a checkpoint prefix or the name of a directory to save - the checkpoint in (in which case the checkpoint will be named based on - the Network name). - global_step: The global step to use when naming the checkpoint. If None - (default), we will first try to get the default global step. If that - fails because no default global step exists, then the checkpoint is - created without a global step suffix. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. - Returns: - The checkpoint prefix for the saved checkpoint, which may be passed to - `Network.restore`. - Raises: - ValueError: If the Network has not yet been called, or if map_func results - in a name collision. - """ - if not self.built: - raise ValueError( - "Attempt to save the Network before it was first called. This means " - "variables have not yet been created, so there is nothing to save.") - self._set_scope() # scope_name should be available to map_funcs - if global_step is None: - global_step = training_util.get_global_step() - if os.path.isdir(save_path): - # If we were passed a directory, default to naming based on the Network - # name. - save_path = os.path.join(save_path, self.name) - user_map_func = map_func - if map_func is None: - map_func = self._strip_variable_prefix - variable_map = {} - for variable in self.variables: - mapped_name = map_func(variable._shared_name) - if variable_map.setdefault(mapped_name, variable) is not variable: - if user_map_func is None: - # Instead of erroring out, we could just re-try and silently use the - # full variable names in the checkpoint. This could be odd for deeply - # nested sub-Networks (since the full prefix from the nesting would - # get added), so for now we'll let the user deal with this case. - raise ValueError(_default_naming_conflict_error_message( - mapped_name=mapped_name, - first_variable=variable_map[mapped_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - else: - # The user passed their own problematic map_func. - raise ValueError( - ("The map_func passed to Network.save for the Network '%s' " - "resulted in two variables named '%s' ('%s' and '%s'). Try " - "stripping less from the variable names, or renaming parts of " - "the Network. For reference, variables created by sub-Layers of " - "this Network are prefixed with '%s', but if they are re-used " - "after being added to another Network, they will have that " - "Network's full variable prefix instead.") % ( - self.name, mapped_name, - variable_map[mapped_name]._shared_name, - variable._shared_name, - self.scope_name)) - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - return saver_lib.Saver(variable_map).save( - sess=sess, save_path=save_path, write_meta_graph=False, - global_step=global_step) + @property + def losses(self): + """Gather losses from `Layer`s in the `Network`. - def _restore_existing_variables(self, save_path, map_func, user_map_func): - """Use a standard Saver to restore existing variables from a checkpoint. + Note that when executing eagerly, `Layer.losses` evaluates + regularizers. When using graph execution, variable regularization ops have + already been created and are simply returned here. - Args: - save_path: The checkpoint prefix or directory to read from. - map_func: The function to use when mapping from variable names to - checkpoint names. - user_map_func: The original map_func passed by the user, for error - checking. Returns: - A dictionary mapping from checkpoint names to variable objects which have - been restored (for bookkeeping to avoid deferred restorations on these - variables). - Raises: - ValueError: If there is a name collision. - """ - existing_variables_by_checkpoint_name = {} - for variable in self.variables: - checkpoint_name = map_func(variable._shared_name) - if existing_variables_by_checkpoint_name.setdefault( - checkpoint_name, variable) is not variable: - if user_map_func is None: - raise ValueError(_default_naming_conflict_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - else: - raise ValueError(_restore_custom_map_func_error_message( - mapped_name=checkpoint_name, - first_variable=existing_variables_by_checkpoint_name[ - checkpoint_name], - second_variable=variable, - network_name=self.name, - network_scope_name=self.scope_name)) - if existing_variables_by_checkpoint_name: - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( - sess=sess, save_path=save_path) - return existing_variables_by_checkpoint_name - - def _set_restore_on_create(self, save_path, map_func, user_map_func, - existing_variables_by_checkpoint_name): - """If necessary, request deferred restorations of variables.""" - checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) - checkpointed_variables_to_restore = {} - for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): - if checkpoint_name in existing_variables_by_checkpoint_name: - # This variable was already created and restored. - continue - # Save the variable for later restoration in a custom getter. - checkpointed_variables_to_restore[checkpoint_name] = ( - checkpoint_reader.get_tensor(checkpoint_name)) - # Only set a deferred restoration if there are checkpoint variables which - # have not been assigned to existing variables. Note that this loses out on - # some opportunity for error checking, but avoids creating - # _DeferredRestoration objects once a Network has been built (so that - # restoring in a loop does not take increasing amounts of memory). - if checkpointed_variables_to_restore: - if context.in_eager_mode(): - sess = None - else: - sess = ops.get_default_session() - # We need a name for error messages. If we haven't been added to another - # Network yet, we're top-level. - self._finalize_name(False) - self._set_scope() - # Save a record of this restoration for use in the custom getter. - deferred_restoration = _DeferredRestoration( - map_func=map_func, - map_func_is_user=(user_map_func is not None), - checkpointed_variables_to_restore=checkpointed_variables_to_restore, - restored_variables={}, - session=sess, - network_name=self.name, - network_scope_name=self.scope_name) - self._deferred_restorations.append(deferred_restoration) - # Add the deferred registration to non-Network children, and request that - # Networks propagate the request to their children. - self._add_deferred_restoration(deferred_restoration) - - def _add_deferred_restoration(self, deferred_restoration): - """Add a deferred restoration to this Network and all children. - - Restorations which are requested later have higher priority, and the highest - priority matching restoration is applied to a variable when it is created. - - Args: - deferred_restoration: A _DeferredRestoration object. + A list of tensors. """ - # Networks don't create variables at the moment, so this append isn't - # strictly necessary. We could get by with only adding deferred restorations - # to non-Network Layers. - self._set_scope() - # We use set_custom_getter because it avoids recursively calling up the - # variable_scope tree. We've done the tree traversal ourselves and have - # added the request to each Layer which needs it. - self._scope.set_custom_getter(self._custom_getter) - self._deferred_restorations.append(deferred_restoration) + layer_losses = [] for layer in self.layers: - if isinstance(layer, Network): - # For Networks, request that they propagate this deferred restoration - # to all of their children recursively. - layer._add_deferred_restoration(deferred_restoration) - else: - # For non-Network Layers, make sure they have a deferred restoration - # queue and a custom getter, then add our request to it. - if not hasattr(layer, "_custom_getter"): - assert not hasattr(layer, "_deferred_restorations") - layer._custom_getter, layer._deferred_restorations = ( - _make_custom_getter_for_deferred_restorations()) - self._set_scope_for_nonnetwork_sublayer(layer) - layer._scope.set_custom_getter(layer._custom_getter) - layer._deferred_restorations.append(deferred_restoration) - - def restore(self, save_path, map_func=None): - """Restore the Network from a checkpoint. - - If variables have already been created (typically when some or all of the - `Network` is built), they are assigned values from the checkpoint - immediately, overwriting any existing values (in graph mode the default - session is used for the assignments). - - If there are checkpoint entries which do not correspond to any existing - variables in the `Network`, these values are saved for deferred restoration; - their initial values will be the checkpointed values once they are - created. Requests for multiple deferred restorations behave the same way as - immediate restorations, in that later requests will take priority over - earlier requests relevant to the same variable. - - If this `Network` shares `Layer`s with another network, those `Layer`s will - also have their variables restored from the checkpoint. - - Args: - save_path: The return value of `Network.save`, or a directory to search - for a checkpoint. - map_func: A function mapping fully qualified variable names - (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By - default (if `map_func=None`), the variable prefix for the network being - restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped - and all other variable names (shared with other Networks) are left - unchanged. Note that this is the _same_ map_func as `Network.save`, not - an inverse mapping. - """ - self._finalize_name(parent_network=False) - self._set_scope() # scope_name should be available to map_funcs - if os.path.isdir(save_path): - # If we don't have a name yet, set no parent. - save_path = os.path.join(save_path, self.name) - user_map_func = map_func - if map_func is None: - map_func = self._strip_variable_prefix - # Step one is to restore any existing variables from the checkpoint. - existing_variables_by_checkpoint_name = self._restore_existing_variables( - save_path=save_path, - map_func=map_func, - user_map_func=user_map_func) - # Step two is to set a custom getter which restores variables on creation, - # for those variables which have not been added to sub-Layers yet. - self._set_restore_on_create( - save_path=save_path, - map_func=map_func, - user_map_func=user_map_func, - existing_variables_by_checkpoint_name=( - existing_variables_by_checkpoint_name)) + layer_losses.extend(layer.losses) + return layer_losses - # TODO(josh11b): Support other Layer methods needed for graph mode, such as for - # losses and updates + # TODO(allenl): Support other Layer methods needed for graph mode, such as for + # updates class Sequential(Network): @@ -799,3 +522,436 @@ class Sequential(Network): else: inputs = l(inputs) return inputs + + +_DeferredRestoration = collections.namedtuple( + + "_DeferredRestoration", + [ + # The map_func to use (either user-specified or the default). + "map_func", + # Boolean, True if the user specified an explicit map_func, for error + # messages. + "map_func_is_user", + # A mapping from checkpoint names to initial values of not-yet-created + # variables which should be restored. These values come from parsing a + # checkpoint. + "checkpointed_variables_to_restore", + # A mapping from checkpoint name to variable objects of variables which + # have already been restored, for error checking. + "restored_variables", + # The session to restore with (if in graph mode). + "session", + # Names of the Network where the restore was requested, for error + # messages. + "network_name", + "network_scope_name" + ]) + + +def _default_naming_conflict_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The default checkpoint variable name mapping strategy for Network " + "'%s' resulted in a naming conflict. We attempted to strip off the " + "variable prefix for the Network ('%s'), but this resulted in two " + "variables named '%s' (originally '%s' and '%s'). This should only " + "happen when using variable sharing (i.e. the Network contains Networks " + "or Layers which were first added to another Network, and therefore " + "have that Network's variable prefix). One solution is to pass " + "`map_func=lambda n: n` to save and restore to use fully qualified " + "variable names in the checkpoint, although this will require that the " + "variable prefix of the Network being restored into is also '%s'. You " + "may alternatively write an arbitrary mapping.") + % ( + network_name, network_scope_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, network_scope_name + )) + + +def _restore_custom_map_func_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The map_func passed to restore_network_checkpoint for the Network '%s' " + "resulted in two variables named '%s' (originally '%s' and '%s'). Since " + "this is also an error when saving, this Network was " + "probably not saved with this map_func. Note that map_func " + "always maps from full variable names to checkpoint names; " + "there is no need to specify an inverse mapping.\n\n" + "Try stripping less from the variable names, or renaming parts " + "of the Network. For reference, variables created by sub-Layers " + "of this Network are prefixed with '%s', but if they are " + "re-used after being added to another Network they will have " + "that Network's full variable prefix instead.") % ( + network_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, + network_scope_name)) + + +def _make_custom_getter_for_deferred_restorations(): + """Returns a custom getter which searches `deferred_restorations`. + + Returns: A tuple of (_custom_getter, deferred_restorations) + _custom_getter: The getter which should be added to variable_scopes where + variables will be created. + deferred_restorations: A list for _DeferredRestoration objects. Typically + empty when the getter is set, and expanded as deferred restorations are + requested. All new deferred restorations should be appended to the end of + the list, where they will have priority over older deferred restorations. + """ + deferred_restorations = [] + + def _custom_getter(getter, name, shape=None, dtype=None, + initializer=None, + *args, **kwargs): + """A custom getter which processes deferred restorations.""" + # Iterate over restorations, newest first (newer restorations will take + # precedence over older restorations, just like with immediate restorations + # into existing variables). + delayed_restoration = None + found_value = False + value_to_restore = None + for delayed_restoration in reversed( + deferred_restorations): + checkpoint_name = delayed_restoration.map_func(name) + if (checkpoint_name + in delayed_restoration.checkpointed_variables_to_restore): + found_value = True + value_to_restore = ( + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name]) + if found_value: + break + # value_to_restore may be False because this variable is not in any + # checkpoint we are restoring, or None because we have explicitly set it to + # None when it was previously fetched. In either case, we don't need to + # set an initializer. + if found_value and value_to_restore is not None: + initializer = value_to_restore + shape = None + variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, + *args, **kwargs) + if found_value and value_to_restore is not None: + # Mark as already restored from this checkpoint. + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name] = None + if context.in_graph_mode(): + delayed_restoration.session.run(variable.initializer) + if found_value: + # Error checking should run even if we've already restored a value. + if delayed_restoration.restored_variables.setdefault( + checkpoint_name, variable) is not variable: + # Naming conflict. We've tried to initialize two variables with the + # same value from the checkpoint. + if delayed_restoration.map_func_is_user: + raise ValueError( + _restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + else: + raise ValueError( + _default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + return variable + return _custom_getter, deferred_restorations + + +def _make_prefix_stripping_map_fn(scope_name): + """Closure for stripping the scope name of a Network. + + Implemented as a closure rather than a member function to avoid reference + cycles in deferred restorations (this function should not have a reference to + the Network which created it). + + Args: + scope_name: The Network.scope_name to strip from variables. + Returns: + A scope_name-stripping default `map_fn` for the Network. + """ + + def _strip_variable_prefix(original_variable_name): + """The default map_func for saving or restoring variables. + + Strips the variable prefix for the Network on which save/restore was called, + and leaves other variable names fully qualified in the checkpoint. + + Args: + original_variable_name: The _shared_name of the variable (no :0 + suffix) to map. + Returns: + The checkpoint name of the variable. + """ + scope_name_with_slash = scope_name + "/" + if original_variable_name.startswith(scope_name_with_slash): + return original_variable_name[len(scope_name_with_slash):] + else: + return original_variable_name + + return _strip_variable_prefix + + +def save_network_checkpoint( + network, save_path, global_step=None, map_func=None): + """Save variables from the Network to a checkpoint. + + Args: + network: A Network object to save. + save_path: Either a checkpoint prefix or the name of a directory to save + the checkpoint in (in which case the checkpoint will be named based on + the Network name). + global_step: The global step to use when naming the checkpoint. If None + (default), we will first try to get the default global step. If that + fails because no default global step exists, then the checkpoint is + created without a global step suffix. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. + Returns: + The checkpoint prefix for the saved checkpoint, which may be passed to + `Network.restore`. + Raises: + ValueError: If the Network has not yet been called, or if map_func results + in a name collision. + """ + if not network.built: + raise ValueError( + "Attempt to save the Network before it was first called. This means " + "variables have not yet been created, so there is nothing to save.") + network._set_scope() # scope_name should be available to map_funcs + if global_step is None: + global_step = training_util.get_global_step() + if os.path.isdir(save_path): + # If we were passed a directory, default to naming based on the Network + # name. + save_path = os.path.join(save_path, network.name.replace("/", "_")) + user_map_func = map_func + if map_func is None: + map_func = _make_prefix_stripping_map_fn(network.scope_name) + variable_map = {} + for variable in network.variables: + mapped_name = map_func(variable._shared_name) + if variable_map.setdefault(mapped_name, variable) is not variable: + if user_map_func is None: + # Instead of erroring out, we could just re-try and silently use the + # full variable names in the checkpoint. This could be odd for deeply + # nested sub-Networks (since the full prefix from the nesting would + # get added), so for now we'll let the user deal with this case. + raise ValueError(_default_naming_conflict_error_message( + mapped_name=mapped_name, + first_variable=variable_map[mapped_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + else: + # The user passed their own problematic map_func. + raise ValueError( + ("The map_func passed to save_network_checkpoint for the Network " + "'%s' resulted in two variables named '%s' ('%s' and '%s'). Try " + "stripping less from the variable names, or renaming parts of " + "the Network. For reference, variables created by sub-Layers of " + "this Network are prefixed with '%s', but if they are re-used " + "after being added to another Network, they will have that " + "Network's full variable prefix instead.") % ( + network.name, mapped_name, + variable_map[mapped_name]._shared_name, + variable._shared_name, + network.scope_name)) + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + return saver_lib.Saver(variable_map).save( + sess=sess, save_path=save_path, write_meta_graph=False, + global_step=global_step) + + +def _add_deferred_restoration(layer, deferred_restoration): + """Add a deferred restoration to this Layer and all children. + + Restorations which are requested later have higher priority, and the highest + priority matching restoration is applied to a variable when it is created. + + Args: + layer: The Layer (may not be a Network) to operate on. + deferred_restoration: A _DeferredRestoration object. + """ + # Networks don't create variables at the moment, so this append isn't strictly + # necessary. We could get by with only adding deferred restorations to + # non-Network Layers. + if isinstance(layer, Network): + layer._set_scope() + # Make sure this Layer has a deferred restoration queue and a custom getter, + # then add our request to it. + if not hasattr(layer, "_custom_getter"): + assert not hasattr(layer, "_deferred_restorations") + layer._custom_getter, layer._deferred_restorations = ( + _make_custom_getter_for_deferred_restorations()) + # We use set_custom_getter because it avoids recursively calling up the + # variable_scope tree. We've done the tree traversal ourselves and have added + # the request to each Layer which needs it. + layer._scope.set_custom_getter(layer._custom_getter) + layer._deferred_restorations.append(deferred_restoration) + if isinstance(layer, Network): + for sublayer in layer.layers: + if not isinstance(sublayer, Network): + layer._set_scope_for_nonnetwork_sublayer(sublayer) + _add_deferred_restoration(sublayer, deferred_restoration) + + +def _restore_existing_variables(network, save_path, map_func, user_map_func): + """Use a standard Saver to restore existing variables from a checkpoint. + + Args: + network: A Network object to restore. + save_path: The checkpoint prefix or directory to read from. + map_func: The function to use when mapping from variable names to + checkpoint names. + user_map_func: The original map_func passed by the user, for error + checking. + Returns: + A dictionary mapping from checkpoint names to variable objects which have + been restored (for bookkeeping to avoid deferred restorations on these + variables). + Raises: + ValueError: If there is a name collision. + """ + existing_variables_by_checkpoint_name = {} + for variable in network.variables: + checkpoint_name = map_func(variable._shared_name) + if existing_variables_by_checkpoint_name.setdefault( + checkpoint_name, variable) is not variable: + if user_map_func is None: + raise ValueError(_default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + else: + raise ValueError(_restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=network.name, + network_scope_name=network.scope_name)) + if existing_variables_by_checkpoint_name: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( + sess=sess, save_path=save_path) + return existing_variables_by_checkpoint_name + + +def _set_restore_on_create(network, save_path, map_func, user_map_func, + existing_variables_by_checkpoint_name): + """If necessary, request deferred restorations of variables.""" + checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) + checkpointed_variables_to_restore = {} + for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): + if checkpoint_name in existing_variables_by_checkpoint_name: + # This variable was already created and restored. + continue + # Save the variable for later restoration in a custom getter. + checkpointed_variables_to_restore[checkpoint_name] = ( + checkpoint_reader.get_tensor(checkpoint_name)) + # Only set a deferred restoration if there are checkpoint variables which + # have not been assigned to existing variables. Note that this loses out on + # some opportunity for error checking, but avoids creating + # _DeferredRestoration objects once a Network has been built (so that + # restoring in a loop does not take increasing amounts of memory). + if checkpointed_variables_to_restore: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + # We need a name for error messages. If we haven't been added to another + # Network yet, we're top-level. + network._finalize_name(False) + network._set_scope() + # Save a record of this restoration for use in the custom getter. + deferred_restoration = _DeferredRestoration( + map_func=map_func, + map_func_is_user=(user_map_func is not None), + checkpointed_variables_to_restore=checkpointed_variables_to_restore, + restored_variables={}, + session=sess, + network_name=network.name, + network_scope_name=network.scope_name) + # Add the deferred registration to non-Network children, and request that + # Networks propagate the request to their children. + _add_deferred_restoration(network, deferred_restoration) + + +def restore_network_checkpoint(network, save_path, map_func=None): + """Restore the Network from a checkpoint. + + If variables have already been created (typically when some or all of the + `Network` is built), they are assigned values from the checkpoint immediately, + overwriting any existing values (in graph mode the default session is used for + the assignments). + + If there are checkpoint entries which do not correspond to any existing + variables in the `Network`, these values are saved for deferred restoration; + their initial values will be the checkpointed values once they are + created. Requests for multiple deferred restorations behave the same way as + immediate restorations, in that later requests will take priority over earlier + requests relevant to the same variable. + + If this `Network` shares `Layer`s with another network, those `Layer`s will + also have their variables restored from the checkpoint. + + Args: + network: A Network object to restore. + save_path: The return value of `tfe.save_network_checkpoint`, or a directory + to search for a checkpoint. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. Note that this is the _same_ map_func as + `tfe.save_network_checkpoint`, not an inverse mapping. + """ + network._finalize_name(parent_network=False) + network._set_scope() # scope_name should be available to map_funcs + if os.path.isdir(save_path): + # If we don't have a name yet, set no parent. + save_path = os.path.join(save_path, network.name.replace("/", "_")) + user_map_func = map_func + if map_func is None: + map_func = _make_prefix_stripping_map_fn(network.scope_name) + # Step one is to restore any existing variables from the checkpoint. + existing_variables_by_checkpoint_name = _restore_existing_variables( + network=network, + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func) + # Step two is to set a custom getter which restores variables on creation, + # for those variables which have not been added to sub-Layers yet. + _set_restore_on_create( + network=network, + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func, + existing_variables_by_checkpoint_name=( + existing_variables_by_checkpoint_name)) diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index c621f527c28306131bdba56d8427eaa787ba150b..3eb4f5f8b3954a7ed04d2ef1d4f119ad137e1e65 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -19,9 +19,13 @@ from __future__ import print_function import gc from tensorflow.contrib.eager.python import network +from tensorflow.contrib.layers.python.layers import regularizers +from tensorflow.python.eager import context +from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.layers import core from tensorflow.python.ops import math_ops @@ -42,12 +46,28 @@ class MyNetwork(network.Network): return self.l1(x) +class RegularizedNetwork(network.Network): + + def __init__(self): + super(RegularizedNetwork, self).__init__() + self.l1 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0), + kernel_regularizer=regularizers.l1_regularizer(2.0))) + self.l2 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0))) + + def call(self, values): + return self.l2(self.l1(values)) + + class NetworkTest(test.TestCase): def _save_modify_load_network_built(self, net, global_step=None): checkpoint_directory = self.get_temp_dir() - checkpoint_path = net.save( - save_path=checkpoint_directory, global_step=global_step) + checkpoint_path = network.save_network_checkpoint( + network=net, save_path=checkpoint_directory, global_step=global_step) input_value = constant_op.constant([[42.0]]) original_output = self.evaluate(net(input_value)) for var in net.variables: @@ -56,18 +76,18 @@ class NetworkTest(test.TestCase): self.evaluate(net(input_value)), original_output) # Either the returned explicit checkpoint path or the directory should work. - net.restore(save_path=checkpoint_directory) + network.restore_network_checkpoint(net, save_path=checkpoint_directory) self.assertAllEqual( original_output, self.evaluate(net(input_value))) for var in net.variables: self.evaluate(var.assign(var + 2.)) - net.restore(save_path=checkpoint_path) + network.restore_network_checkpoint(net, save_path=checkpoint_path) self.assertAllEqual( original_output, self.evaluate(net(input_value))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testTrainableAttribute(self): net = network.Network() self.assertTrue(net.trainable) @@ -75,7 +95,7 @@ class NetworkTest(test.TestCase): net.trainable = False self.assertTrue(net.trainable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNetworkCall(self): net = MyNetwork(name="abcd") net(constant_op.constant([[2.0]])) # Force variables to be created. @@ -85,17 +105,36 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) + # TODO(akshayka): This test should be changed once an API for compiling + # `call` into a defun is implemented. + def testReplacingNetworkCallWithDefun(self): + net = MyNetwork(name="abcd") + x = constant_op.constant([[2.0]]) + net(x) # Force variables to be created. + self.evaluate(net.trainable_variables[0].assign([[17.0]])) + + net.call = function.defun(net.call) + result = net(x) # Build and execute the TensorFlow function + self.assertEqual(34.0, self.evaluate(result)) + + # Force the creation of another TensorFlow function by changing input shape + y = constant_op.constant([[1.0], [2.0]]) + result = net(y) + self.assertAllEqual([[17.0], [34.0]], self.evaluate(result)) + + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") with self.assertRaisesRegexp( ValueError, "Attempt to save the Network before it was first called"): - net.save(self.get_temp_dir()) + network.save_network_checkpoint(net, self.get_temp_dir()) net(constant_op.constant([[2.0]])) self.evaluate(net.trainable_variables[0].assign([[17.0]])) self._save_modify_load_network_built(net, global_step=None) self._save_modify_load_network_built(net, global_step=10) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testSaveRestoreDefaultGlobalStep(self): net = MyNetwork(name="abcd") @@ -103,9 +142,10 @@ class NetworkTest(test.TestCase): self.evaluate(net.variables[0].assign([[3.]])) default_global_step = training_util.get_or_create_global_step() self.evaluate(default_global_step.assign(4242)) - save_path = net.save(self.get_temp_dir()) + save_path = network.save_network_checkpoint(net, self.get_temp_dir()) self.assertIn("abcd-4242", save_path) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveAndRestoreIntoUnbuilt(self): save_dir = self.get_temp_dir() @@ -113,16 +153,43 @@ class NetworkTest(test.TestCase): test_input = constant_op.constant([[2.0]]) net1(test_input) self.evaluate(net1.trainable_variables[0].assign([[17.0]])) - save_path = net1.save(save_dir) + save_path = network.save_network_checkpoint(net1, save_dir) # With a pre-build restore we should have the same value. net2 = MyNetwork() - net2.restore(save_path) + network.restore_network_checkpoint(net2, save_path) self.assertAllEqual(self.evaluate(net1(test_input)), self.evaluate(net2(test_input))) self.assertIsNot(net1.variables[0], net2.variables[0]) self.assertAllEqual(self.evaluate(net1.variables[0]), self.evaluate(net2.variables[0])) + @test_util.run_in_graph_and_eager_modes() + def testNetworkMatchesLayerVariableNames(self): + zero = constant_op.constant([[0.]]) + layer_one = core.Dense(1, use_bias=False) + layer_one(zero) + layer_two = core.Dense(1, use_bias=False) + layer_two(zero) + + class TwoLayerNet(network.Network): + + def __init__(self, name=None): + super(TwoLayerNet, self).__init__(name=name) + self.first = self.track_layer(core.Dense( + 1, use_bias=False)) + self.second = self.track_layer(core.Dense( + 1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + net = TwoLayerNet() + net(zero) + self.assertEqual("two_layer_net/" + layer_one.variables[0].name, + net.first.variables[0].name) + self.assertEqual("two_layer_net/" + layer_two.variables[0].name, + net.second.variables[0].name) + @test_util.run_in_graph_and_eager_modes() def testLoadIntoUnbuiltSharedLayer(self): @@ -170,14 +237,15 @@ class NetworkTest(test.TestCase): # Re-map the variable names so that with default restore mapping we'll # attempt to restore into the unbuilt Layer. name_mapping = { - "checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel", + "checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel", "checkpoint_creator/second_layer/kernel": "second_layer/kernel", } - save_path = checkpoint_creator.save( + save_path = network.save_network_checkpoint( + checkpoint_creator, self.get_temp_dir(), map_func=lambda full_name: name_mapping[full_name]) load_into = User(use_layer=first_owner.first) - load_into.restore(save_path) + network.restore_network_checkpoint(load_into, save_path) self.assertEqual(0, len(first_owner.variables)) self.assertAllEqual(self.evaluate(checkpoint_creator(one)), self.evaluate(load_into(one))) @@ -193,12 +261,13 @@ class NetworkTest(test.TestCase): del first_owner gc.collect() def _restore_map_func(original_name): - if original_name.startswith("owner_1"): - return original_name.replace("owner_1", "owner_2") + if original_name.startswith("owner/"): + return original_name.replace("owner/", "owner_1/") else: - return "user_2/" + original_name + return "user_1/" + original_name with self.assertRaisesRegexp(ValueError, "garbage collected"): - load_into.restore(save_path, map_func=_restore_map_func) + network.restore_network_checkpoint( + load_into, save_path, map_func=_restore_map_func) @test_util.run_in_graph_and_eager_modes() def testRestoreIntoSubNetwork(self): @@ -218,17 +287,18 @@ class NetworkTest(test.TestCase): whole_model_saver(one) self.evaluate(whole_model_saver.variables[0].assign([[15.]])) self.evaluate(whole_model_saver.variables[1].assign([[16.]])) - whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir()) + whole_model_checkpoint = network.save_network_checkpoint( + whole_model_saver, self.get_temp_dir()) save_from = MyNetwork() save_from(one) self.evaluate(save_from.variables[0].assign([[5.]])) - checkpoint = save_from.save(self.get_temp_dir()) + checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir()) save_into_parent = Parent() - save_into_parent.restore(whole_model_checkpoint) - save_into_parent.first.restore(checkpoint) - save_into_parent.first.restore(checkpoint) # deferred loading multiple - # times is fine + network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint) + network.restore_network_checkpoint(save_into_parent.first, checkpoint) + # deferred loading multiple times is fine + network.restore_network_checkpoint(save_into_parent.first, checkpoint) save_into_parent(one) # deferred loading self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0])) self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1])) @@ -237,9 +307,9 @@ class NetworkTest(test.TestCase): # (deferred restoration should happen the same way non-deferred happens, # with later restorations overwriting older ones). save_into_parent = Parent() - save_into_parent.first.restore(checkpoint) # deferred loading multiple - # times is fine - save_into_parent.restore(whole_model_checkpoint) + # deferred loading multiple times is fine + network.restore_network_checkpoint(save_into_parent.first, checkpoint) + network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint) save_into_parent(one) # deferred loading # We've overwritten the sub-Network restore. self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0])) @@ -247,12 +317,12 @@ class NetworkTest(test.TestCase): self.evaluate(save_into_parent.variables[0].assign([[3.]])) self.evaluate(save_into_parent.variables[1].assign([[4.]])) - save_into_parent.second.restore(checkpoint) + network.restore_network_checkpoint(save_into_parent.second, checkpoint) self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1])) with self.assertRaisesRegexp(errors_impl.NotFoundError, "not found in checkpoint"): # The checkpoint is incompatible. - save_into_parent.restore(checkpoint) + network.restore_network_checkpoint(save_into_parent, checkpoint) @test_util.run_in_graph_and_eager_modes() def testCustomMapCollisionErrors(self): @@ -274,31 +344,36 @@ class NetworkTest(test.TestCase): self.evaluate(make_checkpoint.variables[1].assign([[3.]])) with self.assertRaisesRegexp( ValueError, - "The map_func passed to Network.save for the Network 'parent_1' " - "resulted in two variables named 'foo'"): - make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo") - checkpoint = make_checkpoint.first.save( - self.get_temp_dir(), map_func=lambda n: "foo") + "The map_func passed to save_network_checkpoint for the Network " + "'parent' resulted in two variables named 'foo'"): + network.save_network_checkpoint( + make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo") + checkpoint = network.save_network_checkpoint( + network=make_checkpoint.first, + save_path=self.get_temp_dir(), + map_func=lambda n: "foo") loader = Parent() - loader.restore(checkpoint, map_func=lambda n: "foo") + network.restore_network_checkpoint( + loader, checkpoint, map_func=lambda n: "foo") with self.assertRaisesRegexp( ValueError, - ("The map_func passed to Network.restore for the Network" - " 'parent_2' resulted in two variables named 'foo'")): + ("The map_func passed to restore_network_checkpoint for the Network" + " 'parent_1' resulted in two variables named 'foo'")): loader(one) loader = Parent() loader(one) with self.assertRaisesRegexp( ValueError, - ("The map_func passed to Network.restore for the Network" - " 'parent_3' resulted in two variables named 'foo'")): - loader.restore(checkpoint, map_func=lambda n: "foo") + ("The map_func passed to restore_network_checkpoint for the Network" + " 'parent_2' resulted in two variables named 'foo'")): + network.restore_network_checkpoint( + loader, checkpoint, map_func=lambda n: "foo") @test_util.run_in_graph_and_eager_modes() def testDefaultMapCollisionErrors(self): one = constant_op.constant([[1.]]) - first = core.Dense(1, name="dense_1", use_bias=False) + first = core.Dense(1, name="dense", use_bias=False) first(one) class Parent(network.Network): @@ -319,8 +394,8 @@ class NetworkTest(test.TestCase): with self.assertRaisesRegexp( ValueError, ("The default checkpoint variable name mapping strategy for Network " - "'parent_1' resulted in a naming conflict.")): - make_checkpoint.save(self.get_temp_dir()) + "'parent' resulted in a naming conflict.")): + network.save_network_checkpoint(make_checkpoint, self.get_temp_dir()) class Compatible(network.Network): @@ -334,14 +409,15 @@ class NetworkTest(test.TestCase): successful_checkpoint = Compatible() successful_checkpoint(one) self.evaluate(successful_checkpoint.variables[0].assign([[-1.]])) - checkpoint_path = successful_checkpoint.save(self.get_temp_dir()) + checkpoint_path = network.save_network_checkpoint( + successful_checkpoint, self.get_temp_dir()) load_checkpoint = Parent() load_checkpoint(one) with self.assertRaisesRegexp( ValueError, ("The default checkpoint variable name mapping strategy for Network " - "'parent_2' resulted in a naming conflict.")): - load_checkpoint.restore(checkpoint_path) + "'parent_1' resulted in a naming conflict.")): + network.restore_network_checkpoint(load_checkpoint, checkpoint_path) def testNoReferenceCyclesAfterCall(self): @@ -377,25 +453,67 @@ class NetworkTest(test.TestCase): gc.set_debug(previous_gc_debug_flags) gc.enable() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testAnonymousNoNameInitially(self): net = MyNetwork() with self.assertRaisesRegexp(ValueError, "does not yet have a final name"): net.name # pylint: disable=pointless-statement - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testExplicitHasNameInitially(self): net = MyNetwork(name="abcd") self.assertEqual("abcd", net.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testUsingResourceVariables(self): net = MyNetwork() net(constant_op.constant([[0.]])) self.assertIsInstance(net.trainable_weights[0], resource_variable_ops.ResourceVariable) - @test_util.run_in_graph_and_eager_modes() + def testGraphOpNames(self): + """Network operation names should match variable naming.""" + + def _check_op_prefixes(expected_prefix, checked_ops): + for operation in ops.get_default_graph().get_operations(): + if operation.name == "ignore": + continue + if operation.name in checked_ops: + continue + checked_ops.add(operation.name) + self.assertStartsWith(expected_start=expected_prefix, + actual=operation.name) + self.assertNotIn("my_network", operation.name[len(expected_prefix):]) + self.assertNotIn("dense", operation.name[len(expected_prefix):]) + + with context.graph_mode(): + net = MyNetwork() + zero = constant_op.constant([[0.]], name="ignore") + net(zero) + checked_ops = set() + _check_op_prefixes(expected_prefix="my_network/dense/", + checked_ops=checked_ops) + net.net2 = net.track_layer(MyNetwork()) + net.net2(zero) + _check_op_prefixes(expected_prefix="my_network/my_network/dense/", + checked_ops=checked_ops) + MyNetwork()(zero) + _check_op_prefixes(expected_prefix="my_network_1/dense/", + checked_ops=checked_ops) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableRegularizers(self): + net = RegularizedNetwork() + net(constant_op.constant([[1.]])) + self.evaluate(net.variables[0].assign([[2.]])) + self.evaluate(net.variables[1].assign([3.])) + self.evaluate(net.variables[2].assign([[-2.]])) + self.evaluate(net.variables[3].assign([4.])) + self.assertAllEqual([4., 6., 8.], self.evaluate(net.losses)) + self.evaluate(net.variables[3].assign([5.])) + self.assertAllEqual([4., 6., 10.], self.evaluate(net.losses)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) net = MyNetwork(name="foo") @@ -405,21 +523,105 @@ class NetworkTest(test.TestCase): net1 = MyNetwork(name="foo") net1(one) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInVariableScope(self): + one = constant_op.constant([[1.]]) + # Naming happens in the order of first build rather than the order of + # construction, but for clarity they're the same here and construction is + # annotated. + outside_net_before = MyNetwork() # name=my_network + outside_net_before(one) + captured_scope = variable_scope.get_variable_scope() with variable_scope.variable_scope("outside_scope"): - net = MyNetwork() - one = constant_op.constant([[1.]]) - with self.assertRaisesRegexp( - ValueError, - ("Creating Networks inside named variable_scopes is currently not " - "supported")): - net(one) - # Alternatively, we could re-name the Network to match the variable_scope: - # self.assertEqual("outside_scope/my_network_1", net.name) - # self.assertStartsWith( - # expected_start="outside_scope/my_network_1/dense/", - # actual=net.trainable_weights[0].name) + net1 = MyNetwork() # name=outside_scope/my_network + net1(one) + name_conflict1 = MyNetwork(name="name_conflict") # fine, unique so far + name_conflict2 = MyNetwork(name="name_conflict") # error on build + with variable_scope.variable_scope("inside_scope"): + # No issue here since the name is unique within its scope. + name_conflict3 = MyNetwork(name="name_conflict") + net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the + # variable_scope my_network_1 below. + vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below + with variable_scope.variable_scope("intervening_scope"): + with variable_scope.variable_scope(captured_scope): + with variable_scope.variable_scope("outside_scope"): + name_conflict4 = MyNetwork(name="name_conflict") # error on build + with variable_scope.variable_scope("my_network_1"): + pass + with variable_scope.variable_scope("vs_name_conflict"): + pass + net3 = MyNetwork() # name=outside_scope/my_network_4 + name_conflict1(one) + with self.assertRaisesRegexp( + ValueError, "named 'name_conflict' already exists"): + name_conflict2(one) + name_conflict3(one) + net2(one) + with self.assertRaisesRegexp( + ValueError, "or a variable_scope was created with this name"): + vs_name_conflict(one) + with self.assertRaisesRegexp( + ValueError, "named 'name_conflict' already exists"): + name_conflict4(one) + self.assertEqual("outside_scope/name_conflict", + name_conflict1.name) + self.assertStartsWith( + expected_start="outside_scope/name_conflict/dense/", + actual=name_conflict1.variables[0].name) + self.assertEqual("outside_scope/inside_scope/name_conflict", + name_conflict3.name) + self.assertStartsWith( + expected_start="outside_scope/inside_scope/name_conflict/dense/", + actual=name_conflict3.variables[0].name) + self.assertEqual("outside_scope/my_network", net1.name) + self.assertStartsWith( + expected_start="outside_scope/my_network/dense/", + actual=net1.trainable_weights[0].name) + self.assertEqual("outside_scope/my_network_2", net2.name) + self.assertStartsWith( + expected_start="outside_scope/my_network_2/dense/", + actual=net2.trainable_weights[0].name) + net3(one) + self.assertEqual("outside_scope/my_network_3", net3.name) + self.assertStartsWith( + expected_start="outside_scope/my_network_3/dense/", + actual=net3.trainable_weights[0].name) + outside_net_after = MyNetwork() + outside_net_after(one) + self.assertEqual("my_network", outside_net_before.name) + self.assertStartsWith( + expected_start="my_network/dense/", + actual=outside_net_before.trainable_weights[0].name) + self.assertEqual("my_network_1", outside_net_after.name) + self.assertStartsWith( + expected_start="my_network_1/dense/", + actual=outside_net_after.trainable_weights[0].name) + + @test_util.run_in_graph_and_eager_modes() + def testVariableScopeStripping(self): + with variable_scope.variable_scope("scope1"): + with variable_scope.variable_scope("scope2"): + net = MyNetwork() + net(constant_op.constant([[2.0]])) + self.evaluate(net.variables[0].assign([[42.]])) + self.assertEqual(net.name, "scope1/scope2/my_network") + self.assertStartsWith( + expected_start="scope1/scope2/my_network/dense/", + actual=net.trainable_weights[0].name) + save_path = network.save_network_checkpoint(net, self.get_temp_dir()) + self.assertIn("scope1_scope2_my_network", save_path) + restore_net = MyNetwork() + # Delayed restoration + network.restore_network_checkpoint(restore_net, save_path) + restore_net(constant_op.constant([[1.0]])) + self.assertAllEqual([[42.]], + self.evaluate(restore_net.variables[0])) + self.evaluate(restore_net.variables[0].assign([[-1.]])) + # Immediate restoration + network.restore_network_checkpoint(restore_net, save_path) + self.assertAllEqual([[42.]], + self.evaluate(restore_net.variables[0])) @test_util.run_in_graph_and_eager_modes() def testLayerNamesRespected(self): @@ -436,11 +638,11 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = ParentNetwork() net(one) - self.assertStartsWith(expected_start="parent_network_1/explicit_name/", + self.assertStartsWith(expected_start="parent_network/explicit_name/", actual=net.trainable_weights[0].name) self.assertEqual("explicit_name", net.first.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInAnonymousVariableScope(self): # Named outside variable_scopes are not supported at the moment. However, # blank-named top level variable scopes do not change variable names, and so @@ -455,20 +657,20 @@ class NetworkTest(test.TestCase): net(one) self.assertTrue(was_called[0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testReasonableSlashError(self): with self.assertRaisesRegexp( ValueError, "not allowed in Network names"): MyNetwork(name="slash/slash") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNoVariableScopeNames(self): with self.assertRaisesRegexp( ValueError, "VariableScopes are not valid Network names"): with variable_scope.variable_scope("some_scope") as vs: MyNetwork(name=vs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testVariableScopeNameCollision(self): with variable_scope.variable_scope("abcd"): pass @@ -478,7 +680,7 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net(one) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNetworkVariablesDoNotInterfere(self): core.Dense(1, use_bias=True) # Should not interfere with naming. net1 = MyNetwork() @@ -491,15 +693,15 @@ class NetworkTest(test.TestCase): # locally so that previous Layer consutrciton does not interfere with # variable naming (e.g. add a Layer construction before the Network, # suddenly your previously saved checkpoint is incompatible). - self.assertEqual("dense_1", net1.l1.name) - self.assertEqual("dense_1", net2.l1.name) + self.assertEqual("dense", net1.l1.name) + self.assertEqual("dense", net2.l1.name) self.evaluate(net1.trainable_weights[0].assign([[1.]])) self.evaluate(net2.trainable_weights[0].assign([[2.]])) self.assertEqual(2., self.evaluate(net2.trainable_weights[0])) self.assertEqual(1., self.evaluate(net1.trainable_weights[0])) - self.assertStartsWith(expected_start="my_network_1/dense_1/", + self.assertStartsWith(expected_start="my_network/dense/", actual=net1.trainable_weights[0].name) - self.assertStartsWith(expected_start="my_network_2/dense_1/", + self.assertStartsWith(expected_start="my_network_1/dense/", actual=net2.trainable_weights[0].name) @test_util.run_in_graph_and_eager_modes() @@ -520,31 +722,31 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = ParentNetwork() net(one) - self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network/my_network/dense", actual=net.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network/my_network/dense", actual=net.first.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network/my_network_1/dense", actual=net.trainable_weights[1].name) - self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network/my_network_1/dense", actual=net.second.trainable_weights[0].name) - self.assertEqual("parent_network_1", net.name) - self.assertEqual("my_network_1", net.first.name) - self.assertEqual("my_network_2", net.second.name) + self.assertEqual("parent_network", net.name) + self.assertEqual("my_network", net.first.name) + self.assertEqual("my_network_1", net.second.name) net2 = ParentNetwork() net2(one) - self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network/dense", actual=net2.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network/dense", actual=net2.first.trainable_weights[0].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", actual=net2.trainable_weights[1].name) - self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", actual=net2.second.trainable_weights[0].name) - self.assertEqual("parent_network_2", net2.name) - self.assertEqual("my_network_1", net2.first.name) - self.assertEqual("my_network_2", net2.second.name) + self.assertEqual("parent_network_1", net2.name) + self.assertEqual("my_network", net2.first.name) + self.assertEqual("my_network_1", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testNestableExplicit(self): @@ -605,26 +807,26 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net = MixedLayerNetwork() net(one) - self.assertEqual("dense_1", net.first.name) - self.assertEqual("dense_2", net.second.name) - self.assertEqual("dense_3", net.third.name) - self.assertEqual("dense_4", net.fourth.name) - self.assertEqual("dense_5", net.fifth.name) + self.assertEqual("dense", net.first.name) + self.assertEqual("dense_1", net.second.name) + self.assertEqual("dense_2", net.third.name) + self.assertEqual("dense_3", net.fourth.name) + self.assertEqual("dense_4", net.fifth.name) # Note that this is _not_ the default naming behavior for Layers. Layers # which are added to Networks follow Network variable naming conventions # (i.e. variable names = network name unless variable sharing). Nested # Layers revert to Layer behavior. - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_1/", + self.assertStartsWith(expected_start="mixed_layer_network/dense/", actual=net.trainable_weights[0].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_2/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_1/", actual=net.trainable_weights[1].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_3/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_2/", actual=net.trainable_weights[2].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_4/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_3/", actual=net.trainable_weights[3].name) - self.assertStartsWith(expected_start="mixed_layer_network_1/dense_5/", + self.assertStartsWith(expected_start="mixed_layer_network/dense_4/", actual=net.trainable_weights[4].name) - self.assertEqual("mixed_layer_network_1", net.name) + self.assertEqual("mixed_layer_network", net.name) @test_util.run_in_graph_and_eager_modes() def testNestableExplicitCollisions(self): @@ -677,24 +879,24 @@ class NetworkTest(test.TestCase): net = ParentNetwork() net(one) self.assertStartsWith( - expected_start="parent_network_1/first_unique_child_name/dense_1/", + expected_start="parent_network/first_unique_child_name/dense/", actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start="parent_network_1/second_unique_child_name/dense_1/", + expected_start="parent_network/second_unique_child_name/dense/", actual=net.trainable_weights[1].name) - self.assertEqual("parent_network_1", net.name) + self.assertEqual("parent_network", net.name) self.assertEqual("first_unique_child_name", net.first.name) self.assertEqual("second_unique_child_name", net.second.name) net2 = ParentNetwork() net2(one) self.assertStartsWith( - expected_start="parent_network_2/first_unique_child_name/dense", + expected_start="parent_network_1/first_unique_child_name/dense", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="parent_network_2/second_unique_child_name/dense", + expected_start="parent_network_1/second_unique_child_name/dense", actual=net2.trainable_weights[1].name) - self.assertEqual("parent_network_2", net2.name) + self.assertEqual("parent_network_1", net2.name) self.assertEqual("first_unique_child_name", net2.first.name) self.assertEqual("second_unique_child_name", net2.second.name) @@ -752,15 +954,15 @@ class NetworkTest(test.TestCase): net2(one) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_1/dense_1/", + expected_start="first_parent_network/my_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_parent_network_1/my_network_1/dense_1/", + expected_start="second_parent_network/my_network/dense/", actual=net2.trainable_weights[1].name) - self.assertEqual("second_parent_network_1", net2.name) + self.assertEqual("second_parent_network", net2.name) self.assertTrue(net2.first is net.first) - self.assertEqual("my_network_1", net2.first.name) - self.assertEqual("my_network_1", net2.second.name) + self.assertEqual("my_network", net2.first.name) + self.assertEqual("my_network", net2.second.name) # No name collision; the owned Network is added first and has a different # name than the shared Network. @@ -778,15 +980,15 @@ class NetworkTest(test.TestCase): net3(one) self.assertStartsWith( - expected_start="third_parent_network_1/my_network_1/dense", + expected_start="third_parent_network/my_network/dense", actual=net3.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_2/dense", + expected_start="first_parent_network/my_network_1/dense", actual=net3.trainable_weights[1].name) - self.assertEqual("third_parent_network_1", net3.name) + self.assertEqual("third_parent_network", net3.name) self.assertTrue(net3.second is net.second) - self.assertEqual("my_network_1", net3.first.name) - self.assertEqual("my_network_2", net3.second.name) + self.assertEqual("my_network", net3.first.name) + self.assertEqual("my_network_1", net3.second.name) # "Unavoidable" same-name Layer. The owned name is added first (fixed), then # a shared Network is added with the same name. @@ -804,15 +1006,15 @@ class NetworkTest(test.TestCase): net4(one) self.assertStartsWith( - expected_start="fourth_parent_network_1/my_network_1/dense_1/", + expected_start="fourth_parent_network/my_network/dense/", actual=net4.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_parent_network_1/my_network_1/dense_1/", + expected_start="first_parent_network/my_network/dense/", actual=net4.trainable_weights[1].name) - self.assertEqual("fourth_parent_network_1", net4.name) + self.assertEqual("fourth_parent_network", net4.name) self.assertTrue(net4.second is net.first) - self.assertEqual("my_network_1", net4.first.name) - self.assertEqual("my_network_1", net4.second.name) + self.assertEqual("my_network", net4.first.name) + self.assertEqual("my_network", net4.second.name) @test_util.run_in_graph_and_eager_modes() def testRecursiveLayerRenaming(self): @@ -843,28 +1045,28 @@ class NetworkTest(test.TestCase): net(one) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_1/" - "dense_1/"), + expected_start=("parent_network/network_with_layer_children/" + "dense/"), actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_1/" - "dense_2/"), + expected_start=("parent_network/network_with_layer_children/" + "dense_1/"), actual=net.trainable_weights[1].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_2/" - "dense_1/"), + expected_start=("parent_network/network_with_layer_children_1/" + "dense/"), actual=net.trainable_weights[2].name) self.assertStartsWith( - expected_start=("parent_network_1/network_with_layer_children_2/" - "dense_2/"), + expected_start=("parent_network/network_with_layer_children_1/" + "dense_1/"), actual=net.trainable_weights[3].name) - self.assertEqual("parent_network_1", net.name) - self.assertEqual("network_with_layer_children_1", net.first.name) - self.assertEqual("network_with_layer_children_2", net.second.name) - self.assertEqual("dense_1", net.first.first.name) - self.assertEqual("dense_2", net.first.second.name) - self.assertEqual("dense_1", net.second.first.name) - self.assertEqual("dense_2", net.second.second.name) + self.assertEqual("parent_network", net.name) + self.assertEqual("network_with_layer_children", net.first.name) + self.assertEqual("network_with_layer_children_1", net.second.name) + self.assertEqual("dense", net.first.first.name) + self.assertEqual("dense_1", net.first.second.name) + self.assertEqual("dense", net.second.first.name) + self.assertEqual("dense_1", net.second.second.name) @test_util.run_in_graph_and_eager_modes() def testCallInDifferentOrderThanConstruct(self): @@ -898,23 +1100,23 @@ class NetworkTest(test.TestCase): net1(one) self.assertStartsWith( - expected_start="first_network_1/my_network_1/dense_1/", + expected_start="first_network/my_network/dense/", actual=net1.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/my_network_2/dense_1/", + expected_start="first_network/my_network_1/dense/", actual=net1.trainable_weights[1].name) self.assertStartsWith( - expected_start="first_network_1/my_network_1/dense_1/", + expected_start="first_network/my_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_network_1/my_network_1/dense_1/", + expected_start="second_network/my_network/dense/", actual=net2.trainable_weights[1].name) self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) - self.assertEqual("first_network_1", net1.name) - self.assertEqual("my_network_1", net1.first.name) - self.assertEqual("my_network_2", net1.second.name) + self.assertEqual("first_network", net1.name) + self.assertEqual("my_network", net1.first.name) + self.assertEqual("my_network_1", net1.second.name) self.assertTrue(net2.first is net1.first) - self.assertEqual("my_network_1", net2.second.name) + self.assertEqual("my_network", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testLayerCallInDifferentOrderThanConstruct(self): @@ -951,23 +1153,23 @@ class NetworkTest(test.TestCase): net1(one) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net1.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/dense_2/", + expected_start="first_network/dense_1/", actual=net1.trainable_weights[1].name) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net2.trainable_weights[0].name) self.assertStartsWith( - expected_start="second_network_1/dense_1/", + expected_start="second_network/dense/", actual=net2.trainable_weights[1].name) self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) - self.assertEqual("first_network_1", net1.name) - self.assertEqual("dense_1", net1.first.name) - self.assertEqual("dense_2", net1.second.name) + self.assertEqual("first_network", net1.name) + self.assertEqual("dense", net1.first.name) + self.assertEqual("dense_1", net1.second.name) self.assertTrue(net2.first is net1.first) - self.assertEqual("dense_1", net2.second.name) + self.assertEqual("dense", net2.second.name) @test_util.run_in_graph_and_eager_modes() def testLayerAlreadyBuilt(self): @@ -996,17 +1198,18 @@ class NetworkTest(test.TestCase): # do not match their layer names. actual=net.trainable_weights[0].name) self.assertStartsWith( - expected_start="first_network_1/dense_1/", + expected_start="first_network/dense/", actual=net.trainable_weights[1].name) self.assertTrue( net.trainable_weights[0] is shared_layer.trainable_weights[0]) - self.assertEqual("first_network_1", net.name) + self.assertEqual("first_network", net.name) self.assertEqual("dense_3", net.first.name) - self.assertEqual("dense_1", net.second.name) + self.assertEqual("dense", net.second.name) class SequentialTest(test.TestCase): + @test_util.assert_no_garbage_created def testTwoLayers(self): # Create a sequential network with one layer. net = network.Sequential([core.Dense(1, use_bias=False)]) @@ -1028,6 +1231,7 @@ class SequentialTest(test.TestCase): l2.trainable_variables[0].assign([[11.0]]) self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy()) + @test_util.assert_no_garbage_created def testFunctions(self): # Create a sequential network with one function. net = network.Sequential([nn_ops.relu]) @@ -1038,6 +1242,7 @@ class SequentialTest(test.TestCase): net.add(math_ops.negative) self.assertEqual(-2.0, net(two).numpy()) + @test_util.assert_no_garbage_created def testTrainingLayer(self): net = network.Sequential([core.Dropout(0.99999)]) two = constant_op.constant(2.0) @@ -1051,6 +1256,7 @@ class SequentialTest(test.TestCase): # Should only fail spuriously 1 in 10^100 runs. self.fail("Didn't see dropout happen after 20 tries.") + @test_util.assert_no_garbage_created def testTrainingFunction(self): # Output depends on value of "training". def add_training(input_value, training=None): diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index e0a20d2485e831b1841991596b91429c6eaa2854..57b070ec6eeac00c77f199a846639d64c4957cd8 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -23,7 +23,6 @@ from tensorflow.python.eager import context from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.training import adam as _adam from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as _saver @@ -171,20 +170,12 @@ class Saver(object): def get_optimizer_variables(optimizer): """Returns a list of variables for the given `tf.train.Optimizer`. + Equivalent to `optimizer.variables()`. + Args: optimizer: An instance of `tf.train.Optimizer` which has created variables (typically after a call to `Optimizer.minimize`). Returns: - A list of variables which have been created by the `Optimizer`. Currently - returns all variables even if they were not created in the default graph, - but this behavior may change. + A list of variables which have been created by the `Optimizer`. """ - variables = [] - # pylint: disable=protected-access - for _, variable_dict in optimizer._slots.items(): - for _, slot_for_variable in variable_dict.items(): - variables.append(slot_for_variable) - if isinstance(optimizer, _adam.AdamOptimizer): - variables.append(optimizer._beta1_power) - variables.append(optimizer._beta2_power) - return variables + return optimizer.variables() diff --git a/tensorflow/contrib/eager/python/summary_writer.py b/tensorflow/contrib/eager/python/summary_writer.py deleted file mode 100644 index 5d8c41b545b3c9fd03af85f302ba05a394f085a4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TensorBoard Summary Writer for TensorFlow Eager Execution.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import uuid - -from tensorflow.contrib.summary import gen_summary_ops -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_op_util -from tensorflow.python.ops import variable_scope - - -def _maybe_cpu(v): - if isinstance(v, (ops.EagerTensor, ops.Tensor)): - return v.cpu() - else: - return v - - -def _summary_writer_function(name, tensor, function, family=None): - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - function(tag, scope) - return True - return record - - -class SummaryWriter(object): - """Writes summaries for TensorBoard, compatible with eager execution. - - This class is the supported way of writing TensorBoard summaries under - eager execution. - """ - - _CPU_DEVICE = "cpu:0" - - def __init__(self, - logdir, - max_queue=10, - flush_secs=120, - filename_suffix=""): - """Summary writer for TensorBoard, compatible with eager execution. - - If necessary, multiple instances of `SummaryWriter` can be created, with - distinct `logdir`s and `name`s. Each `SummaryWriter` instance will retain - its independent `global_step` counter and data writing destination. - - Example: - ```python - writer = tfe.SummaryWriter("my_model") - - # ... Code that sets up the model and data batches ... - - for _ in xrange(train_iters): - loss = model.train_batch(batch) - writer.scalar("loss", loss) - writer.step() - ``` - - Args: - logdir: Directory in which summary files will be written. - max_queue: Number of summary items to buffer before flushing to - filesystem. If 0, summaries will be flushed immediately. - flush_secs: Number of secondsbetween forced commits to disk. - filename_suffix: Suffix of the event protobuf files in which the summary - data are stored. - - Raises: - ValueError: If this constructor is called not under eager execution. - """ - # TODO(apassos, ashankar): Make this class and the underlying - # contrib.summary_ops compatible with graph model and remove this check. - if not context.in_eager_mode(): - raise ValueError( - "Use of SummaryWriter is currently supported only with eager " - "execution enabled. File an issue at " - "https://github.com/tensorflow/tensorflow/issues/new to express " - "interest in fixing this.") - - # TODO(cais): Consider adding name keyword argument, which if None or empty, - # will register the global global_step that training_util.get_global_step() - # can find. - with context.device(self._CPU_DEVICE): - self._name = uuid.uuid4().hex - self._global_step = 0 - self._global_step_tensor = variable_scope.get_variable( - "global_step/summary_writer/" + self._name, - shape=[], dtype=dtypes.int64, - initializer=init_ops.zeros_initializer()) - self._global_step_dirty = False - self._resource = gen_summary_ops.summary_writer(shared_name=self._name) - gen_summary_ops.create_summary_file_writer( - self._resource, logdir, max_queue, flush_secs, filename_suffix) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device=self._CPU_DEVICE) - - def step(self): - """Increment the global step counter of this SummaryWriter instance.""" - self._global_step += 1 - self._global_step_dirty = True - - @property - def global_step(self): - """Obtain the current global_step value of this SummaryWriter instance. - - Returns: - An `int` representing the current value of the global_step of this - `SummaryWriter` instance. - """ - return self._global_step - - def _update_global_step_tensor(self): - with context.device(self._CPU_DEVICE): - if self._global_step_dirty: - self._global_step_dirty = False - return state_ops.assign(self._global_step_tensor, self._global_step) - else: - return self._global_step_tensor - - def generic(self, name, tensor, metadata, family=None): - """Write a generic-type summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A `Tensor` or compatible value type containing the value of the - summary. - metadata: Metadata about the summary. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_summary( - self._resource, - self._update_global_step_tensor(), - _maybe_cpu(tensor), - tag, - _maybe_cpu(metadata), - name=scope) - - def scalar(self, name, tensor, family=None): - """Write a scalar summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type containing a - single value. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - - Returns: - A summary writer function for scalars. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_scalar_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def histogram(self, name, tensor, family=None): - """Write a histogram summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type. Any shape. - Values to use to build the histogram. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_histogram_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def image(self, name, tensor, bad_color=None, max_images=3, family=None): - """Write an image summary.""" - with context.device(self._CPU_DEVICE): - if bad_color is None: - bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_image_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), bad_color_, max_images, - name=scope) - - def audio(self, name, tensor, sample_rate, max_outputs, family=None): - """Write an audio summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` - or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`, or - compatible value type. - sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the - signal in hertz. - max_outputs: Max number of batch elements to generate audio for. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_audio_summary( - self._resource, self._update_global_step_tensor(), - tag, - _maybe_cpu(tensor), - sample_rate=_maybe_cpu(sample_rate), - max_outputs=max_outputs, - name=scope) diff --git a/tensorflow/contrib/eager/python/summary_writer_test.py b/tensorflow/contrib/eager/python/summary_writer_test.py deleted file mode 100644 index 5ebb36d04fcba8f4558fa1c09716314af42f559f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Unit tests for eager execution SummaryWriter.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil -import tempfile - -import numpy as np - -from tensorflow.contrib.eager.python import summary_writer -from tensorflow.core.util import event_pb2 -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.lib.io import tf_record -from tensorflow.python.platform import gfile - - -class SummaryWriterTest(test.TestCase): - - def setUp(self): - super(SummaryWriterTest, self).setUp() - self._test_device = "gpu:0" if context.num_gpus() else "cpu:0" - self._tmp_logdir = tempfile.mkdtemp() - with context.device(self._test_device): - # Use max_queue=0 so that summaries are immediately flushed to filesystem, - # making testing easier. - self._writer = summary_writer.SummaryWriter(self._tmp_logdir, max_queue=0) - - def tearDown(self): - if os.path.isdir(self._tmp_logdir): - shutil.rmtree(self._tmp_logdir) - super(SummaryWriterTest, self).tearDown() - - def _readLastEvent(self, logdir=None): - if not logdir: - logdir = self._tmp_logdir - files = [f for f in gfile.ListDirectory(logdir) - if not gfile.IsDirectory(os.path.join(logdir, f))] - file_path = os.path.join(logdir, files[0]) - records = list(tf_record.tf_record_iterator(file_path)) - event = event_pb2.Event() - event.ParseFromString(records[-1]) - return event - - def testGlobalStep(self): - with context.device(self._test_device): - orig_step = self._writer.global_step - self._writer.step() - self.assertEqual(orig_step + 1, self._writer.global_step) - self.assertEqual(orig_step + 1, self._writer.global_step) - self._writer.step() - self._writer.step() - self.assertEqual(orig_step + 3, self._writer.global_step) - - def testGenericSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - with context.device("cpu:0"): - metadata = constant_op.constant("foo") - self._writer.generic("x", x, metadata) - event = self._readLastEvent() - self.assertEqual("x", event.summary.value[0].tag) - - def testScalarSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - self._writer.scalar("x", x) - event = self._readLastEvent() - self.assertTrue("x", event.summary.value[0].tag) - self.assertEqual(1337.0, event.summary.value[0].simple_value) - - def testHistogramSummary(self): - with context.device(self._test_device): - y = constant_op.constant([1.0, 3.0, 3.0, 7.0]) - self._writer.histogram("y", y) - event = self._readLastEvent() - self.assertEqual("y", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].histo) - - def testImageSummary(self): - with context.device(self._test_device): - a = constant_op.constant([[10.0, 20.0], [-20.0, -10.0]]) - self._writer.histogram("image1", a) - event = self._readLastEvent() - self.assertEqual("image1", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].image) - - def testAudioSummary(self): - with context.device(self._test_device): - w = constant_op.constant(np.random.rand(3, 10, 2), dtype=dtypes.float32) - fs = constant_op.constant(44100.0, dtype=dtypes.float32) - max_outputs = 1 - self._writer.audio("audio1", w, fs, max_outputs) - event = self._readLastEvent() - self.assertTrue(event.summary.value[0].audio) - - def testTwoSummaryWritersGlobalStepsWorkWithoutCrosstalk(self): - tmp_logdir2 = os.path.join(self._tmp_logdir, "_writer2_") - writer2 = summary_writer.SummaryWriter(tmp_logdir2, max_queue=0) - - self.assertEqual(0, writer2.global_step) - self._writer.step() - self.assertEqual(0, writer2.global_step) - writer2.step() - writer2.step() - writer2.step() - self.assertEqual(3, writer2.global_step) - - x = constant_op.constant(1337.0) - writer_orig_step = self._writer.global_step - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 1, event.step) - - writer2.scalar("x", x) - event = self._readLastEvent(tmp_logdir2) - self.assertEqual(3, event.step) - - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 2, event.step) - - -# TODO(cais): Add performance benchmark for SummaryWriter. - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index b6c687c82946ec62ccb90165791587dc335f13c7..770a7e3e7a01f3351c229b7fb53383240dd1f1c8 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -23,6 +23,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@list_devices @@num_gpus +@@py_func @@defun @@implicit_gradients @@implicit_value_and_gradients @@ -30,9 +31,6 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@value_and_gradients_function @@GradientTape -@@enable_tracing -@@flush_trace - @@run @@enable_eager_execution @@ -46,13 +44,16 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@seterr @@Iterator -@@Network @@Saver @@restore_variables_on_create @@Variable @@get_optimizer_variables @@EagerVariableStore +@@Network +@@save_network_checkpoint +@@restore_network_checkpoint + @@in_eager_mode @@in_graph_mode @@ -74,6 +75,8 @@ from __future__ import print_function from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.eager.python.datasets import Iterator from tensorflow.contrib.eager.python.network import Network +from tensorflow.contrib.eager.python.network import save_network_checkpoint +from tensorflow.contrib.eager.python.network import restore_network_checkpoint from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver @@ -86,7 +89,6 @@ from tensorflow.python.eager.context import in_eager_mode from tensorflow.python.eager.context import in_graph_mode from tensorflow.python.eager.context import list_devices from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.core import enable_tracing from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks @@ -100,8 +102,10 @@ from tensorflow.python.framework.test_util import IsolateTest from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore +from tensorflow.python.ops import script_ops from tensorflow.python.util.all_util import remove_undocumented +py_func = script_ops.eager_py_func defun = function.defun implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index a0f83ac10555913b5be177f0f2b00b2b0e30494a..ba272d7e885434eb556cbafd3d9e64a50d21f9b2 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -7,6 +7,7 @@ package( licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") filegroup( name = "all_files", @@ -26,10 +27,13 @@ py_library( srcs_version = "PY2AND3", deps = [ ":dnn", + ":dnn_linear_combined", ":extenders", ":head", + ":linear", ":logit_fns", ":multi_head", + ":replicate_model_fn", "//tensorflow/python:util", ], ) @@ -71,6 +75,46 @@ py_test( ], ) +py_library( + name = "dnn_linear_combined", + srcs = ["python/estimator/dnn_linear_combined.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:nn", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:dnn_linear_combined", + ], +) + +py_test( + name = "dnn_linear_combined_test", + size = "medium", + srcs = ["python/estimator/dnn_linear_combined_test.py"], + shard_count = 3, + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":dnn_linear_combined", + ":head", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:nn", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python/estimator:dnn_testing_utils", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:linear_testing_utils", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "extenders", srcs = [ @@ -167,6 +211,42 @@ py_test( ], ) +py_library( + name = "linear", + srcs = ["python/estimator/linear.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:linear", + ], +) + +py_test( + name = "linear_test", + size = "small", + srcs = ["python/estimator/linear_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + ":head", + ":linear", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:linear_testing_utils", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + py_library( name = "logit_fns", srcs = [ @@ -202,10 +282,14 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:summary", "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/saved_model:signature_constants", "@six_archive//:six", @@ -233,3 +317,63 @@ py_test( "@six_archive//:six", ], ) + +py_library( + name = "replicate_model_fn", + srcs = [ + "python/estimator/replicate_model_fn.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device", + "//tensorflow/python:device_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:util", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "replicate_model_fn_test", + size = "medium", + srcs = ["python/estimator/replicate_model_fn_test.py"], + additional_deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:dnn", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ":replicate_model_fn", + ], + tags = ["multi_gpu"], +) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index cf727264cd5116915f6bd7f285e470cbc2e2742a..28c1f8b1809d27db697365b7bb50441f7820d2b4 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -20,10 +20,13 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long,wildcard-import from tensorflow.contrib.estimator.python.estimator.dnn import * +from tensorflow.contrib.estimator.python.estimator.dnn_linear_combined import * from tensorflow.contrib.estimator.python.estimator.extenders import * from tensorflow.contrib.estimator.python.estimator.head import * +from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * +from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -38,9 +41,12 @@ _allowed_symbols = [ 'multi_label_head', 'regression_head', 'DNNEstimator', + 'DNNLinearCombinedEstimator', + 'LinearEstimator', 'call_logit_fn', 'dnn_logit_fn_builder', 'linear_logit_fn_builder', + 'replicate_model_fn', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..ccaf1128bf23af734f7a5722a4dd8c1f0304fab7 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py @@ -0,0 +1,164 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow estimator for Linear and DNN joined training models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import dnn_linear_combined as dnn_linear_combined_lib +from tensorflow.python.ops import nn + + +class DNNLinearCombinedEstimator(estimator.Estimator): + """An estimator for TensorFlow Linear and DNN joined models with custom head. + + Note: This estimator is also known as wide-n-deep. + + Example: + + ```python + numeric_feature = numeric_column(...) + categorical_column_a = categorical_column_with_hash_bucket(...) + categorical_column_b = categorical_column_with_hash_bucket(...) + + categorical_feature_a_x_categorical_feature_b = crossed_column(...) + categorical_feature_a_emb = embedding_column( + categorical_column=categorical_feature_a, ...) + categorical_feature_b_emb = embedding_column( + categorical_column=categorical_feature_b, ...) + + estimator = DNNLinearCombinedEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + # wide settings + linear_feature_columns=[categorical_feature_a_x_categorical_feature_b], + linear_optimizer=tf.train.FtrlOptimizer(...), + # deep settings + dnn_feature_columns=[ + categorical_feature_a_emb, categorical_feature_b_emb, + numeric_feature], + dnn_hidden_units=[1000, 500, 100], + dnn_optimizer=tf.train.ProximalAdagradOptimizer(...)) + + # To apply L1 and L2 regularization, you can set optimizers as follows: + tf.train.ProximalAdagradOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001, + l2_regularization_strength=0.001) + # It is same for FtrlOptimizer. + + # Input builders + def input_fn_train: # returns x, y + pass + estimator.train(input_fn=input_fn_train, steps=100) + + def input_fn_eval: # returns x, y + pass + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + pass + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * for each `column` in `dnn_feature_columns` + `linear_feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss is calculated by using mean squared error. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + model_dir=None, + linear_feature_columns=None, + linear_optimizer='Ftrl', + dnn_feature_columns=None, + dnn_optimizer='Adagrad', + dnn_hidden_units=None, + dnn_activation_fn=nn.relu, + dnn_dropout=None, + input_layer_partitioner=None, + config=None): + """Initializes a DNNLinearCombinedEstimator instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + linear_feature_columns: An iterable containing all the feature columns + used by linear part of the model. All items in the set must be + instances of classes derived from `FeatureColumn`. + linear_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the linear part of the model. Defaults to FTRL optimizer. + dnn_feature_columns: An iterable containing all the feature columns used + by deep part of the model. All items in the set must be instances of + classes derived from `FeatureColumn`. + dnn_optimizer: An instance of `tf.Optimizer` used to apply gradients to + the deep part of the model. Defaults to Adagrad optimizer. + dnn_hidden_units: List of hidden units per layer. All layers are fully + connected. + dnn_activation_fn: Activation function applied to each layer. If None, + will use `tf.nn.relu`. + dnn_dropout: When not None, the probability we will drop out + a given coordinate. + input_layer_partitioner: Partitioner for input layer. Defaults to + `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + config: RunConfig object to configure the runtime settings. + + Raises: + ValueError: If both linear_feature_columns and dnn_features_columns are + empty at the same time. + """ + linear_feature_columns = linear_feature_columns or [] + dnn_feature_columns = dnn_feature_columns or [] + self._feature_columns = ( + list(linear_feature_columns) + list(dnn_feature_columns)) + if not self._feature_columns: + raise ValueError('Either linear_feature_columns or dnn_feature_columns ' + 'must be defined.') + + def _model_fn(features, labels, mode, config): + return dnn_linear_combined_lib._dnn_linear_combined_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + linear_feature_columns=linear_feature_columns, + linear_optimizer=linear_optimizer, + dnn_feature_columns=dnn_feature_columns, + dnn_optimizer=dnn_optimizer, + dnn_hidden_units=dnn_hidden_units, + dnn_activation_fn=dnn_activation_fn, + dnn_dropout=dnn_dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + + super(DNNLinearCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e4d34dc70ccaa4806ae8b8ed5001bd971ee7b4 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined_test.py @@ -0,0 +1,220 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for dnn_linear_combined.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import dnn_linear_combined +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.python.estimator.canned import dnn_testing_utils +from tensorflow.python.estimator.canned import linear_testing_utils +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache + + +def _dnn_only_estimator_fn( + hidden_units, + feature_columns, + model_dir=None, + label_dimension=1, + weight_column=None, + optimizer='Adagrad', + activation_fn=nn.relu, + dropout=None, + input_layer_partitioner=None, + config=None): + return dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + model_dir=model_dir, + dnn_feature_columns=feature_columns, + dnn_optimizer=optimizer, + dnn_hidden_units=hidden_units, + dnn_activation_fn=activation_fn, + dnn_dropout=dropout, + input_layer_partitioner=input_layer_partitioner, + config=config) + + +class DNNOnlyEstimatorEvaluateTest( + dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__( + self, _dnn_only_estimator_fn) + + +class DNNOnlyEstimatorPredictTest( + dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorPredictTest.__init__( + self, _dnn_only_estimator_fn) + + +class DNNOnlyEstimatorTrainTest( + dnn_testing_utils.BaseDNNRegressorTrainTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + dnn_testing_utils.BaseDNNRegressorTrainTest.__init__( + self, _dnn_only_estimator_fn) + + +def _linear_only_estimator_fn( + feature_columns, + model_dir=None, + label_dimension=1, + weight_column=None, + optimizer='Ftrl', + config=None, + partitioner=None): + return dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + model_dir=model_dir, + linear_feature_columns=feature_columns, + linear_optimizer=optimizer, + input_layer_partitioner=partitioner, + config=config) + + +class LinearOnlyEstimatorEvaluateTest( + linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__( + self, _linear_only_estimator_fn) + + +class LinearOnlyEstimatorPredictTest( + linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorPredictTest.__init__( + self, _linear_only_estimator_fn) + + +class LinearOnlyEstimatorTrainTest( + linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorTrainingTest.__init__( + self, _linear_only_estimator_fn) + + +class DNNLinearCombinedEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, + label_dimension, batch_size): + linear_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + dnn_feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + feature_columns = linear_feature_columns + dnn_feature_columns + est = dnn_linear_combined.DNNLinearCombinedEstimator( + head=head_lib.regression_head(label_dimension=label_dimension), + linear_feature_columns=linear_feature_columns, + dnn_feature_columns=dnn_feature_columns, + dnn_hidden_units=(2, 2), + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + # learn y = x + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=label_dimension, + label_dimension=label_dimension, + batch_size=batch_size) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 189f098005b8926bfb30b723cc989cb854a5d77e..a9311a20f127d92f02a95b8b48082fc90850635a 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops @@ -48,7 +49,20 @@ def multi_class_head(n_classes, Uses `sparse_softmax_cross_entropy` loss. - This head expects to be fed integer labels specifying the class index. + The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. + In many applications, the shape is `[batch_size, n_classes]`. + + `labels` must be a dense `Tensor` with shape matching `logits`, namely + `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string + `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, + `labels` must be an integer `Tensor` with values specifying the class index. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. + + The loss is the weighted sum over the input dimensions. Namely, if the input + labels have shape `[batch_size, 1]`, the loss is the weighted sum over + `batch_size`. Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use @@ -57,11 +71,11 @@ def multi_class_head(n_classes, `tf.feature_column.numeric_column` defining feature column representing weights. It is used to down weight or boost examples during training. It will be multiplied by the loss of the example. - label_vocabulary: A list of strings represents possible label values. If it - is not given, that means labels are already encoded as integer within - [0, n_classes). If given, labels must be string type and have any value in - `label_vocabulary`. Also there will be errors if vocabulary is not - provided and labels are string. + label_vocabulary: A list or tuple of strings representing possible label + values. If it is not given, that means labels are already encoded as an + integer within [0, n_classes). If given, labels must be of string type and + have any value in `label_vocabulary`. Note that errors will be raised if + `label_vocabulary` is not provided but labels are strings. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -84,7 +98,20 @@ def binary_classification_head( This head uses `sigmoid_cross_entropy_with_logits` loss. - This head expects to be fed float labels of shape `(batch_size, 1)`. + The head expects `logits` with shape `[D0, D1, ... DN, 1]`. + In many applications, the shape is `[batch_size, 1]`. + + `labels` must be a dense `Tensor` with shape matching `logits`, namely + `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string + `Tensor` with values from the vocabulary. If `label_vocabulary` is not given, + `labels` must be float `Tensor` with values in the interval `[0, 1]`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. + + The loss is the weighted sum over the input dimensions. Namely, if the input + labels have shape `[batch_size, 1]`, the loss is the weighted sum over + `batch_size`. Args: weight_column: A string or a `_NumericColumn` created by @@ -96,11 +123,11 @@ def binary_classification_head( generated for each threshold value. This threshold is applied to the logistic values to determine the binary classification (i.e., above the threshold is `true`, below is `false`. - label_vocabulary: A list of strings represents possible label values. If it - is not given, that means labels are already encoded within [0, 1]. If - given, labels must be string type and have any value in - `label_vocabulary`. Also there will be errors if vocabulary is not - provided and labels are string. + label_vocabulary: A list or tuple of strings representing possible label + values. If it is not given, labels must be float with values within + [0, 1]. If given, labels must be string type and have any value in + `label_vocabulary`. Note that errors will be raised if `label_vocabulary` + is not provided but labels are strings. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -120,9 +147,22 @@ def binary_classification_head( def regression_head(weight_column=None, label_dimension=1, name=None): - """Creates a `_Head` for regression using the mean squared loss. + """Creates a `_Head` for regression using the `mean_squared_error` loss. - Uses `mean_squared_error` loss. + The loss is the weighted sum over all input dimensions. Namely, if the input + labels have shape `[batch_size, label_dimension]`, the loss is the weighted + sum over both `batch_size` and `label_dimension`. + + The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`. + In many applications, the shape is `[batch_size, label_dimension]`. + + The `labels` shape must match `logits`, namely + `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape + `[D0, D1, ... DN]` is also supported. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or + `[D0, D1, ... DN, label_dimension]`. Args: weight_column: A string or a `_NumericColumn` created by @@ -156,15 +196,29 @@ def multi_label_head(n_classes, or more associated labels, from a discrete set. This is distinct from `multi_class_head` which has exactly one label per example. - Uses `sigmoid_cross_entropy` loss averaged over classes. Expects labels as a - multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer - `SparseTensor` of class indices. + Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over + the batch. Namely, if the input logits have shape `[batch_size, n_classes]`, + the loss is the average over `n_classes` and the weighted sum over + `batch_size`. + + The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many + applications, the shape is `[batch_size, label_n_classes]`. + + Labels can be: + * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` + * An integer `SparseTensor` of class indices. The `dense_shape` must be + `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`. + * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape` + must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`. + + If `weight_column` is specified, weights must be of shape + `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`. Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns unreduced loss with - shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape - `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the - input labels before passing them to `loss_fn`. + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with + shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies + `label_vocabulary` to the input labels before passing them to `loss_fn`. Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use @@ -172,7 +226,8 @@ def multi_label_head(n_classes, weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing weights. It is used to down weight or boost examples during training. It - will be multiplied by the loss of the example. + will be multiplied by the loss of the example. Per-class weighting is + not supported. thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision and recall metrics are evaluated for each threshold value. The threshold is applied to the predicted probabilities, i.e. above the threshold is @@ -190,7 +245,7 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes` or `thresholds` is invalid. + ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -258,26 +313,36 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access indices=labels.indices, values=label_ids_values, dense_shape=labels.dense_shape) + return math_ops.to_int64( + sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) else: - label_ids = labels - return math_ops.to_int64( - sparse_ops.sparse_to_indicator(label_ids, self._n_classes)) - msg = ('labels shape must be [batch_size, {}]. ' - 'Given: ').format(self._n_classes) - labels_shape = array_ops.shape(labels) - check_rank_op = control_flow_ops.Assert( - math_ops.equal(array_ops.rank(labels), 2), - data=[msg, labels_shape]) - check_label_dim = control_flow_ops.Assert( - math_ops.equal(labels_shape[-1], self._n_classes), - data=[msg, labels_shape]) - with ops.control_dependencies([check_rank_op, check_label_dim]): - return array_ops.identity(labels) + err_msg = ( + r'labels must be an integer SparseTensor with values in ' + r'[0, {})'.format(self._n_classes)) + assert_int = check_ops.assert_integer( + labels.values, message=err_msg) + assert_less = check_ops.assert_less( + labels.values, + ops.convert_to_tensor(self._n_classes, dtype=labels.dtype), + message=err_msg) + assert_greater = check_ops.assert_non_negative( + labels.values, message=err_msg) + with ops.control_dependencies( + [assert_int, assert_less, assert_greater]): + return math_ops.to_int64( + sparse_ops.sparse_to_indicator(labels, self._n_classes)) + err_msg = ( + r'labels must be an integer indicator Tensor with values in [0, 1]') + return head_lib._assert_range(labels, 2, message=err_msg) # pylint:disable=protected-access, def create_loss(self, features, mode, logits, labels): """See `Head`.""" del mode # Unused for this head. + logits = ops.convert_to_tensor(logits) processed_labels = self._process_labels(labels) + processed_labels = head_lib._check_dense_labels_match_logits_and_reshape( # pylint:disable=protected-access + labels=processed_labels, logits=logits, + expected_labels_dimension=self.logits_dimension) if self._loss_fn: unweighted_loss = _call_loss_fn( loss_fn=self._loss_fn, labels=processed_labels, logits=logits, @@ -289,15 +354,23 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access # Averages loss over classes. unweighted_loss = math_ops.reduce_mean( unweighted_loss, axis=-1, keep_dims=True) - return head_lib.LossAndLabels( - unweighted_loss=unweighted_loss, + weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, + features=features, weight_column=self._weight_column, logits=logits) + weighted_sum_loss = losses.compute_weighted_loss( + unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) + # _weights() can return 1. + example_weight_sum = math_ops.reduce_sum( + weights * array_ops.ones_like(unweighted_loss)) + return head_lib.LossSpec( + weighted_sum_loss=weighted_sum_loss, + example_weight_sum=example_weight_sum, processed_labels=processed_labels) def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" with ops.name_scope(self._name, 'head'): - logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access + logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access # Predict. pred_keys = prediction_keys.PredictionKeys @@ -321,22 +394,24 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access export_output.PredictOutput(predictions)) }) + (weighted_sum_loss, example_weight_sum, + processed_labels) = self.create_loss( + features=features, mode=mode, logits=logits, labels=labels) + # Eval. - unweighted_loss, processed_labels = self.create_loss( - features=features, mode=mode, logits=logits, labels=labels) - weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access - training_loss = losses.compute_weighted_loss( - unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) if mode == model_fn.ModeKeys.EVAL: + weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, + features=features, weight_column=self._weight_column, logits=logits) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, - loss=training_loss, + loss=weighted_sum_loss, eval_metric_ops=self._eval_metric_ops( labels=processed_labels, probabilities=probabilities, weights=weights, - unweighted_loss=unweighted_loss)) + weighted_sum_loss=weighted_sum_loss, + example_weight_sum=example_weight_sum)) # Train. if train_op_fn is None: @@ -344,37 +419,43 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access with ops.name_scope(''): summary.scalar( head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), # pylint:disable=protected-access - training_loss) + weighted_sum_loss) summary.scalar( head_lib._summary_key( # pylint:disable=protected-access self._name, metric_keys.MetricKeys.LOSS_MEAN), - losses.compute_weighted_loss( - unweighted_loss, weights=weights, - reduction=losses.Reduction.MEAN)) + weighted_sum_loss / example_weight_sum) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, - loss=training_loss, - train_op=train_op_fn(training_loss)) + loss=weighted_sum_loss, + train_op=train_op_fn(weighted_sum_loss)) - def _eval_metric_ops(self, labels, probabilities, weights, unweighted_loss): + def _eval_metric_ops(self, labels, probabilities, weights, weighted_sum_loss, + example_weight_sum): """Returns a dict of metrics for eval_metric_ops.""" with ops.name_scope( - None, 'metrics', [labels, probabilities, weights, unweighted_loss]): + None, 'metrics', + [labels, probabilities, weights, weighted_sum_loss, example_weight_sum + ]): keys = metric_keys.MetricKeys metric_ops = { # Estimator already adds a metric for loss. head_lib._summary_key(self._name, keys.LOSS_MEAN): # pylint:disable=protected-access metrics_lib.mean( - unweighted_loss, weights=weights, name=keys.LOSS_MEAN), + # Both values and weights here are reduced, scalar Tensors. + # values is the actual mean we want, but we pass the scalar + # example_weight_sum in order to return the correct update_op + # alongside the value_op for streaming metrics. + values=(weighted_sum_loss / example_weight_sum), + weights=example_weight_sum, + name=keys.LOSS_MEAN), head_lib._summary_key(self._name, keys.AUC): # pylint:disable=protected-access - metrics_lib.auc( - labels=labels, predictions=probabilities, weights=weights, - name=keys.AUC), + metrics_lib.auc(labels=labels, predictions=probabilities, + weights=weights, name=keys.AUC), head_lib._summary_key(self._name, keys.AUC_PR): # pylint:disable=protected-access - metrics_lib.auc( - labels=labels, predictions=probabilities, weights=weights, - curve='PR', name=keys.AUC_PR), + metrics_lib.auc(labels=labels, predictions=probabilities, + weights=weights, curve='PR', + name=keys.AUC_PR), } for threshold in self._thresholds: accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold @@ -453,4 +534,3 @@ def _call_loss_fn(loss_fn, labels, logits, features): loss_shape]) with ops.control_dependencies([check_shape_op]): return array_ops.identity(unweighted_loss) - diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index db7d96d508649f93c23b55504088551747f15a26..d1cf9090048470181818c573647923c9f5824dfa 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -226,7 +226,7 @@ class MultiLabelHead(test.TestCase): def test_weight_should_not_impact_prediction(self): n_classes = 4 - head = head_lib.multi_label_head(n_classes, weight_column='label_weights') + head = head_lib.multi_label_head(n_classes, weight_column='example_weights') self.assertEqual(n_classes, head.logits_dimension) logits = np.array( @@ -237,7 +237,7 @@ class MultiLabelHead(test.TestCase): spec = head.create_estimator_spec( features={ 'x': np.array(((42,),), dtype=np.int32), - 'label_weights': weights_2x1, + 'example_weights': weights_2x1, }, mode=model_fn.ModeKeys.PREDICT, logits=logits) @@ -262,17 +262,17 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - expected_unweighted_loss = _sigmoid_cross_entropy( - labels=labels, logits=logits) - actual_unweighted_loss, _ = head.create_loss( + expected_weighted_sum_loss = np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) + actual_weighted_sum_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels) + labels=labels)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose( - expected_unweighted_loss, actual_unweighted_loss.eval()) + self.assertAllClose(expected_weighted_sum_loss, + actual_weighted_sum_loss.eval()) def test_eval_create_loss_large_logits(self): """Tests head.create_loss for eval mode and large logits.""" @@ -286,17 +286,19 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_unweighted_loss = np.array( - [[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32) - actual_unweighted_loss, _ = head.create_loss( + expected_weighted_sum_loss = np.sum( + np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32)) + actual_weighted_sum_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels) + labels=labels)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) + expected_weighted_sum_loss, + actual_weighted_sum_loss.eval(), + atol=1e-4) def test_eval_create_loss_labels_wrong_shape(self): """Tests head.create_loss for eval mode when labels has the wrong shape.""" @@ -305,23 +307,26 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels_placeholder = array_ops.placeholder(dtype=dtypes.int64) - actual_unweighted_loss, _ = head.create_loss( + actual_weighted_sum_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels_placeholder) + labels=labels_placeholder)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'): - actual_unweighted_loss.eval( - {labels_placeholder: np.array([[1], [1]], dtype=np.int64)}) + r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'): + actual_weighted_sum_loss.eval({ + labels_placeholder: np.array([[1], [1]], dtype=np.int64) + }) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'): - actual_unweighted_loss.eval( - {labels_placeholder: np.array([1, 1], dtype=np.int64)}) + r'labels shape must be \[D0, D1, ... DN, 2\]\..*' + r'\[Received shape: \] \[2\]'): + actual_weighted_sum_loss.eval({ + labels_placeholder: np.array([1, 1], dtype=np.int64) + }) def test_eval_create_loss_loss_fn(self): """Tests head.create_loss for eval mode and custom loss_fn.""" @@ -339,14 +344,14 @@ class MultiLabelHead(test.TestCase): return constant_op.constant(loss) head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) - actual_unweighted_loss, _ = head.create_loss( + actual_weighted_sum_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits_input, - labels=labels_input) + labels=labels_input)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(loss, actual_unweighted_loss.eval()) + self.assertAllClose(np.sum(loss), actual_weighted_sum_loss.eval()) def test_eval_create_loss_loss_fn_wrong_shape(self): """Tests custom loss_fn that returns Tensor of unexpected shape.""" @@ -358,18 +363,18 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - actual_unweighted_loss, _ = head.create_loss( + actual_weighted_sum_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels) + labels=labels)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'loss_fn must return Tensor of shape \[batch_size, 1\]\. ' r'Given: \] \[2\]'): - actual_unweighted_loss.eval() + actual_weighted_sum_loss.eval() def test_eval_labels_none(self): """Tests that error is raised when labels is None.""" @@ -383,9 +388,11 @@ class MultiLabelHead(test.TestCase): logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), labels=None) - def _test_eval(self, head, logits, labels, expected_loss, expected_metrics): + def _test_eval( + self, head, logits, labels, expected_loss, expected_metrics, + features=None): spec = head.create_estimator_spec( - features={'x': np.array(((42,),), dtype=np.int32)}, + features=features or {}, mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels) @@ -545,7 +552,7 @@ class MultiLabelHead(test.TestCase): def test_eval_with_weights(self): n_classes = 2 - head = head_lib.multi_label_head(n_classes, weight_column='label_weights') + head = head_lib.multi_label_head(n_classes, weight_column='example_weights') logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) @@ -559,7 +566,7 @@ class MultiLabelHead(test.TestCase): spec = head.create_estimator_spec( features={ 'x': np.array([[41], [42]], dtype=np.int32), - 'label_weights': np.array([[1.], [2.]], dtype=np.float32), + 'example_weights': np.array([[1.], [2.]], dtype=np.float32), }, mode=model_fn.ModeKeys.EVAL, logits=logits, @@ -601,26 +608,39 @@ class MultiLabelHead(test.TestCase): def test_train_create_loss_large_logits(self): """Tests head.create_loss for train mode and large logits.""" n_classes = 2 - head = head_lib.multi_label_head(n_classes) + head = head_lib.multi_label_head(n_classes, weight_column='example_weights') logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + weights = np.array([[1.], [2.]], dtype=np.float32) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_unweighted_loss = np.array( - [[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32) - actual_unweighted_loss, _ = head.create_loss( - features={'x': np.array(((42,),), dtype=np.int32)}, + expected_weighted_sum_loss = np.sum( + np.array( + [[1. * (10. + 10.) / 2.], [2. * (15. + 0.) / 2.]], + dtype=np.float32)) + expected_example_weight_sum = 1. + 2. + actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'example_weights': weights + }, mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels) with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) + expected_weighted_sum_loss, + actual_weighted_sum_loss.eval(), + atol=1e-4) + self.assertAllClose( + expected_example_weight_sum, + actual_example_weight_sum.eval(), + atol=1e-4) def test_train_labels_none(self): """Tests that error is raised when labels is None.""" @@ -638,6 +658,54 @@ class MultiLabelHead(test.TestCase): labels=None, train_op_fn=_no_op_train_fn) + def test_train_invalid_indicator_labels(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + # The value 2 is outside the allowed range. + labels = np.array([[2, 0], [1, 1]], dtype=np.int64) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'labels must be an integer indicator Tensor with values in ' + r'\[0, 1\]'): + sess.run(spec.loss) + + def test_train_invalid_sparse_labels(self): + head = head_lib.multi_label_head(n_classes=2) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + # The value 2 is outside the allowed range. + labels = sparse_tensor.SparseTensor( + values=[2, 0, 1], + indices=[[0, 0], [1, 0], [1, 1]], + dense_shape=[2, 2]) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'labels must be an integer SparseTensor with values in \[0, 2\)'): + sess.run(spec.loss) + def _test_train(self, head, logits, labels, expected_loss): expected_train_result = 'my_train_op' def _train_op_fn(loss): @@ -725,7 +793,7 @@ class MultiLabelHead(test.TestCase): def test_train_with_weights(self): n_classes = 2 - head = head_lib.multi_label_head(n_classes, weight_column='label_weights') + head = head_lib.multi_label_head(n_classes, weight_column='example_weights') logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) @@ -744,7 +812,7 @@ class MultiLabelHead(test.TestCase): spec = head.create_estimator_spec( features={ 'x': np.array([[41], [42]], dtype=np.int32), - 'label_weights': np.array([[1.], [2.]], dtype=np.float32), + 'example_weights': np.array([[1.], [2.]], dtype=np.float32), }, mode=model_fn.ModeKeys.TRAIN, logits=logits, @@ -774,6 +842,153 @@ class MultiLabelHead(test.TestCase): metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3, }, summary_str, tol) + def test_multi_dim_weighted_train_create_loss(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_weighted_sum_loss = 39.6667 + expected_example_weight_sum = np.sum(weights) + actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + features={'weights': weights}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + atol = 1.e-3 + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose( + expected_weighted_sum_loss, actual_weighted_sum_loss.eval(), + atol=atol) + self.assertAllClose( + expected_example_weight_sum, actual_example_weight_sum.eval(), + atol=atol) + + def test_multi_dim_weighted_train(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_loss = 39.6667 + expected_train_result = 'my_train_op' + def _train_op_fn(loss): + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'weights': weights}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + + atol = 1.e-3 + with self.test_session() as sess: + _initialize_variables(self, monitored_session.Scaffold()) + loss, train_result = sess.run((spec.loss, spec.train_op)) + self.assertAllClose(expected_loss, loss, atol=atol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + + def test_multi_dim_weights_wrong_inner_dim(self): + """Logits and labels of shape [2, 2, 3], weights [2, 1].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1.], [2.]], dtype=np.float32) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={'weights': weights}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'): + spec.loss.eval() + + def test_multi_dim_weights_wrong_outer_dim(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2, 3].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[[1., 1., 1.], [1.5, 1.5, 1.5]], + [[2., 2., 2.], [2.5, 2.5, 2.5]]], dtype=np.float32) + weights_placeholder = array_ops.placeholder(dtype=dtypes.float32) + def _train_op_fn(loss): + del loss + return control_flow_ops.no_op() + + spec = head.create_estimator_spec( + features={'weights': weights_placeholder}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 2 3\]'): + spec.loss.eval({weights_placeholder: weights}) + + def test_multi_dim_weighted_eval(self): + """Logits and labels of shape [2, 2, 3], weights [2, 2].""" + head = head_lib.multi_label_head(n_classes=3, weight_column='weights') + + logits = np.array([[[-10., 10., -10.], [10., -10., 10.]], + [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32) + labels = np.array([[[1, 0, 0], [1, 0, 0]], + [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) + weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) + # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 + expected_loss = 39.6667 + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: expected_loss / np.sum(weights), + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. + keys.AUC: 0.4977, + keys.AUC_PR: 0.6645, + } + self._test_eval( + head=head, + features={'weights': weights}, + logits=logits, + labels=labels, + expected_loss=expected_loss, + expected_metrics=expected_metrics) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/linear.py b/tensorflow/contrib/estimator/python/estimator/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..3bf4abe83d54504d55de73b63f369cceaf149dd2 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/linear.py @@ -0,0 +1,118 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator.canned import linear as linear_lib + + +class LinearEstimator(estimator.Estimator): + """An estimator for TensorFlow linear models with user-specified head. + + Example: + + ```python + categorical_column_a = categorical_column_with_hash_bucket(...) + categorical_column_b = categorical_column_with_hash_bucket(...) + + categorical_feature_a_x_categorical_feature_b = crossed_column(...) + + # Estimator using the default optimizer. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b]) + + # Or estimator using the FTRL optimizer with regularization. + estimator = LinearEstimator( + head=tf.contrib.estimator.multi_label_head(n_classes=3), + feature_columns=[categorical_column_a, + categorical_feature_a_x_categorical_feature_b]) + optimizer=tf.train.FtrlOptimizer( + learning_rate=0.1, + l1_regularization_strength=0.001 + )) + + def input_fn_train: # returns x, y (where y represents label's class index). + ... + estimator.train(input_fn=input_fn_train, steps=100) + def input_fn_eval: # returns x, y (where y represents label's class index). + ... + metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10) + def input_fn_predict: # returns x, None + ... + predictions = estimator.predict(input_fn=input_fn_predict) + ``` + + Input of `train` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column` is not `None`, a feature with + `key=weight_column` whose value is a `Tensor`. + * for each `column` in `feature_columns`: + - if `column` is a `_CategoricalColumn`, a feature with `key=column.name` + whose `value` is a `SparseTensor`. + - if `column` is a `_WeightedCategoricalColumn`, two features: the first + with `key` the id column name, the second with `key` the weight column + name. Both features' `value` must be a `SparseTensor`. + - if `column` is a `_DenseColumn`, a feature with `key=column.name` + whose `value` is a `Tensor`. + + Loss and predicted output are determined by the specified head. + + @compatibility(eager) + Estimators are not compatible with eager execution. + @end_compatibility + """ + + def __init__(self, + head, + feature_columns, + model_dir=None, + optimizer='Ftrl', + config=None, + partitioner=None): + """Initializes a `LinearEstimator` instance. + + Args: + head: A `_Head` instance constructed with a method such as + `tf.contrib.estimator.multi_label_head`. + feature_columns: An iterable containing all the feature columns used by + the model. All items in the set should be instances of classes derived + from `FeatureColumn`. + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + optimizer: An instance of `tf.Optimizer` used to train the model. Defaults + to FTRL optimizer. + config: `RunConfig` object to configure the runtime settings. + partitioner: Optional. Partitioner for input layer. + """ + def _model_fn(features, labels, mode, config): + return linear_lib._linear_model_fn( # pylint: disable=protected-access + features=features, + labels=labels, + mode=mode, + head=head, + feature_columns=tuple(feature_columns or []), + optimizer=optimizer, + partitioner=partitioner, + config=config) + super(LinearEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) diff --git a/tensorflow/contrib/estimator/python/estimator/linear_test.py b/tensorflow/contrib/estimator/python/estimator/linear_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c63514eb688af48577f0a3b7ce9e7478309f2c30 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/linear_test.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for linear.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import head as head_lib +from tensorflow.contrib.estimator.python.estimator import linear +from tensorflow.python.estimator.canned import linear_testing_utils +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache + + +def _linear_estimator_fn( + weight_column=None, label_dimension=1, *args, **kwargs): + """Returns a LinearEstimator that uses regression_head.""" + return linear.LinearEstimator( + head=head_lib.regression_head( + weight_column=weight_column, label_dimension=label_dimension), + *args, **kwargs) + + +class LinearEstimatorEvaluateTest( + linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__( + self, _linear_estimator_fn) + + +class LinearEstimatorPredictTest( + linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorPredictTest.__init__( + self, _linear_estimator_fn) + + +class LinearEstimatorTrainTest( + linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + test.TestCase.__init__(self, methodName) + linear_testing_utils.BaseLinearRegressorTrainingTest.__init__( + self, _linear_estimator_fn) + + +class LinearEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension, + label_dimension, batch_size): + feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,))] + est = linear.LinearEstimator( + head=head_lib.regression_head(label_dimension=label_dimension), + feature_columns=feature_columns, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([ + x[prediction_keys.PredictionKeys.PREDICTIONS] + for x in est.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, label_dimension), predictions.shape) + + # EXPORT + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = est.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + label_dimension = 2 + batch_size = 10 + data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32) + data = data.reshape(batch_size, label_dimension) + # learn y = x + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + input_dimension=label_dimension, + label_dimension=label_dimension, + batch_size=batch_size) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index 110ea0302e703fd3eecdfafea928d7ba04f07d8e..09c2862ccd3f90de4153a2095afc9c3d3f9476c1 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -39,6 +39,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core @@ -67,7 +69,8 @@ def call_logit_fn(logit_fn, features, mode, params, config): A logit Tensor, the output of logit_fn. Raises: - ValueError: if logit_fn does not return a Tensor. + ValueError: if logit_fn does not return a Tensor or a dictionary mapping + strings to Tensors. """ logit_fn_args = util.fn_args(logit_fn) kwargs = {} @@ -79,7 +82,15 @@ def call_logit_fn(logit_fn, features, mode, params, config): kwargs['config'] = config logit_fn_results = logit_fn(features=features, **kwargs) - if not isinstance(logit_fn_results, ops.Tensor): - raise ValueError('model_fn should return a Tensor.') + result_is_valid_dictionary = ( + isinstance(logit_fn_results, dict) and + all([(isinstance(k, six.string_types) and isinstance(v, ops.Tensor)) + for k, v in six.iteritems(logit_fn_results)])) + result_is_tensor = isinstance(logit_fn_results, ops.Tensor) + + if not (result_is_valid_dictionary or result_is_tensor): + raise ValueError('logit_fn should return a Tensor or a dictionary mapping ' + 'strings to Tensors. logit_fn returned: %s' % + logit_fn_results) return logit_fn_results diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py b/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py index d75eada798dcdf929e4094258ecdc6ce394f847c..074ece6cca2865b9057ab5ce874a210d3d9ac2e0 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py @@ -43,22 +43,53 @@ class LogitFnTest(test.TestCase): with session.Session(): self.assertAllClose([[4., 5.]], logit_fn_result.eval()) - def test_should_return_tensor(self): + def test_simple_call_multi_logit_fn(self): + + def dummy_logit_fn(features): + return {u'head1': features['f1'], 'head2': features['f2']} + + features = { + 'f1': constant_op.constant([[2., 3.]]), + 'f2': constant_op.constant([[4., 5.]]) + } + logit_fn_result = logit_fns.call_logit_fn(dummy_logit_fn, features, + model_fn.ModeKeys.TRAIN, + 'fake_params', 'fake_config') + with session.Session(): + self.assertAllClose([[2., 3.]], logit_fn_result['head1'].eval()) + self.assertAllClose([[4., 5.]], logit_fn_result['head2'].eval()) + + def test_invalid_logit_fn_results(self): def invalid_logit_fn(features, params): - return { - 'tensor1': features['f1'] * params['input_multiplier'], - 'tensor2': features['f2'] * params['input_multiplier'] - } + return [ + features['f1'] * params['input_multiplier'], + features['f2'] * params['input_multiplier'] + ] + features = { 'f1': constant_op.constant([[2., 3.]]), 'f2': constant_op.constant([[4., 5.]]) } params = {'learning_rate': 0.001, 'input_multiplier': 2.0} - with self.assertRaisesRegexp(ValueError, 'model_fn should return a Tensor'): + with self.assertRaisesRegexp( + ValueError, 'logit_fn should return a Tensor or a dictionary mapping ' + 'strings to Tensors'): logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params, 'fake_config') + def test_invalid_logit_fn_results_dict(self): + + def invalid_logit_fn(features): + return {'head1': features['f1'], 'head2': features['f2']} + + features = {'f1': constant_op.constant([[2., 3.]]), 'f2': 'some string'} + with self.assertRaisesRegexp( + ValueError, 'logit_fn should return a Tensor or a dictionary mapping ' + 'strings to Tensors'): + logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', + 'fake_params', 'fake_config') + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index 64b2a9dee83801b5d6d852a3485fc0cc81417ff0..f2a6eae03ec021e5c28d48b3887870d8a057e077 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -22,10 +22,14 @@ import six from tensorflow.python.estimator import model_fn from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.saved_model import signature_constants +from tensorflow.python.summary import summary _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -72,6 +76,23 @@ def multi_head(heads, head_weights=None): estimator.train(input_fn=input_fn, steps=100) ``` + Also supports `logits` as a `Tensor` of shape + `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the + last dimension and distribute it appropriately among the heads. E.g.: + + ```python + def model_fn(features, labels, mode): + # Create simple heads and specify head name. + head1 = multi_class_head(n_classes=3, name='head1') + head2 = binary_classification_head(name='head2') + # Create multi-head from two simple heads. + head = multi_head([head1, head2]) + # Create logits for the multihead. + logits = logit_fn(logits_dimension=head.logits_dimension) + # Return the merged EstimatorSpec + return head.create_estimator_spec(..., logits=logits, ...) + ``` + Args: heads: List or tuple of `_Head` instances. All heads must have `name` specified. The first head in the list is the default used at serving time. @@ -161,14 +182,53 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access def create_loss(self, features, mode, logits, labels): """See `Head`.""" - # TODO(roumposg): Implement it. - raise NotImplementedError('create_loss not yet implemented for MultiHead.') + if isinstance(logits, dict): + logits_dict = logits + else: + logits_dict = self._split_logits(logits) + weighted_sum_losses = [] + example_weight_sums = [] + labels_by_head = {} + for head in self._heads: + (weighted_sum_loss, + example_weight_sum, processed_labels) = head.create_loss( + features, mode, logits_dict[head.name], labels[head.name]) + weighted_sum_losses.append(weighted_sum_loss) + example_weight_sums.append(example_weight_sum) + labels_by_head[head.name] = processed_labels + + weighted_sum_losses = tuple(weighted_sum_losses) + with ops.name_scope('merge_losses', + values=weighted_sum_losses + (self._head_weights or + tuple())): + if self._head_weights: + head_weighted_losses = [] + head_weighted_example_weight_sums = [] + for loss, example_weight_sum, weight in zip(weighted_sum_losses, + example_weight_sums, + self._head_weights): + head_weighted_losses.append(math_ops.multiply(loss, weight)) + head_weighted_example_weight_sums.append(math_ops.multiply( + example_weight_sum, weight)) + merged_weighted_sum_loss = math_ops.add_n(head_weighted_losses) + merged_example_weight_sum = math_ops.add_n( + head_weighted_example_weight_sums) + else: + merged_weighted_sum_loss = math_ops.add_n(weighted_sum_losses) + merged_example_weight_sum = math_ops.add_n(example_weight_sums) + + return head_lib.LossSpec( + weighted_sum_loss=merged_weighted_sum_loss, + example_weight_sum=merged_example_weight_sum, + processed_labels=labels_by_head) def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `_Head`.""" - if not isinstance(logits, dict): - raise ValueError('logits must be a dict. Given: {}'.format(logits)) + if isinstance(logits, dict): + logits_dict = logits + else: + logits_dict = self._split_logits(logits) if labels and not isinstance(labels, dict): raise ValueError('labels must be a dict. Given: {}'.format(labels)) @@ -179,20 +239,42 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access head.create_estimator_spec( features=features, mode=mode, - logits=logits[head_name], + logits=logits_dict[head_name], labels=labels[head_name] if labels else None, train_op_fn=_no_op_train_fn)) if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None in TRAIN mode.') - return self._merge_train(all_estimator_spec, train_op_fn) + spec = self._merge_train(all_estimator_spec, train_op_fn) + with ops.name_scope(''): + summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss) + return spec if mode == model_fn.ModeKeys.PREDICT: return self._merge_predict(all_estimator_spec) if mode == model_fn.ModeKeys.EVAL: return self._merge_eval(all_estimator_spec) raise ValueError('mode={} unrecognized'.format(mode)) + def _split_logits(self, logits): + """Splits logits along the last dimension and returns a dict.""" + logits_dict = {} + with ops.name_scope(None, 'split_logits', values=[logits]): + logits = ops.convert_to_tensor(logits) + batch_shape = array_ops.shape(logits)[:-1] + zeros_like_batch_shape = array_ops.zeros_like(batch_shape) + minus_ones_like_batch_shape = -1 * array_ops.ones_like(batch_shape) + begin_idx = 0 + for head in self._heads: + begin_tensor = array_ops.concat( + [zeros_like_batch_shape, [begin_idx]], axis=0) + size_tensor = array_ops.concat( + [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0) + logits_dict[head.name] = array_ops.slice( + logits, begin=begin_tensor, size=size_tensor) + begin_idx += head.logits_dimension + return logits_dict + def _merge_train(self, all_estimator_spec, train_op_fn): """Merges list of `EstimatorSpec` for training. @@ -261,14 +343,19 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access predictions = {} metrics = {} losses = [] - for head, spec in zip(self._heads, all_estimator_spec): - losses.append(spec.loss) - head_name = head.name - # Metric keys already contain head.name. - metrics.update(spec.eval_metric_ops or {}) - for k, v in six.iteritems(spec.predictions): - predictions[(head_name, k)] = v - loss = _merge_losses(losses, self._head_weights) + with ops.name_scope('merge_eval'): + for head, spec in zip(self._heads, all_estimator_spec): + losses.append(spec.loss) + head_name = head.name + # Loss metric is not added by default. + loss_name = head_lib._summary_key( # pylint:disable=protected-access + head_name, metric_keys.MetricKeys.LOSS) + metrics[loss_name] = metrics_lib.mean(spec.loss, name=loss_name) + # Metric keys already contain head.name. + metrics.update(spec.eval_metric_ops or {}) + for k, v in six.iteritems(spec.predictions): + predictions[(head_name, k)] = v + loss = _merge_losses(losses, self._head_weights) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 48027035cecffc3ce8aacf8ae917f5eb9e9b2473..68f2d5d1cd53456f7dd82222e171b3619052321a 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -106,7 +106,8 @@ class MultiHeadTest(test.TestCase): multi_head = multi_head_lib.multi_head([head1, head2]) self.assertEqual('head1_head2', multi_head.name) - def test_predict_two_heads(self): + def test_predict_two_heads_logits_dict(self): + """Tests predict with logits as dict.""" head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') multi_head = multi_head_lib.multi_head([head1, head2]) @@ -158,6 +159,111 @@ class MultiHeadTest(test.TestCase): expected_probabilities['head2'], sess.run(spec.export_outputs['head2'].scores)) + def test_predict_two_heads_logits_tensor(self): + """Tests predict with logits as Tensor.""" + head1 = head_lib.multi_label_head(n_classes=2, name='head1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]], dtype=np.float32) + expected_logits1 = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) + expected_logits2 = np.array([[2., -2., 2.], [-3., 2., -2.]], + dtype=np.float32) + expected_probabilities = { + 'head1': _sigmoid(expected_logits1), + 'head2': _sigmoid(expected_logits2), + } + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + self.assertItemsEqual( + (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', + 'head2', 'classification/head2', 'predict/head2'), + spec.export_outputs.keys()) + + # Assert predictions and export_outputs. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + predictions = sess.run(spec.predictions) + self.assertAllClose( + expected_logits1, + predictions[('head1', prediction_keys.PredictionKeys.LOGITS)]) + self.assertAllClose( + expected_logits2, + predictions[('head2', prediction_keys.PredictionKeys.LOGITS)]) + self.assertAllClose( + expected_probabilities['head1'], + predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)]) + self.assertAllClose( + expected_probabilities['head2'], + predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)]) + + self.assertAllClose( + expected_probabilities['head1'], + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) + self.assertAllClose( + expected_probabilities['head1'], + sess.run(spec.export_outputs['head1'].scores)) + self.assertAllClose( + expected_probabilities['head2'], + sess.run(spec.export_outputs['head2'].scores)) + + def test_predict_two_heads_logits_tensor_multi_dim(self): + """Tests predict with multi-dimensional logits of shape [2, 2, 5].""" + head1 = head_lib.regression_head(label_dimension=2, name='head1') + head2 = head_lib.regression_head(label_dimension=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]], + [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]], + dtype=np.float32) + expected_logits1 = np.array( + [[[-1., 1.], [-1., 1.]], + [[-1.5, 1.], [-1.5, 1.]]], + dtype=np.float32) + expected_logits2 = np.array( + [[[2., -2., 2.], [2., -2., 2.]], + [[-3., 2., -2.], [-3., 2., -2.]]], + dtype=np.float32) + + spec = multi_head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + self.assertItemsEqual( + (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1', + 'head2', 'regression/head2', 'predict/head2'), + spec.export_outputs.keys()) + + # Assert predictions and export_outputs. + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNone(spec.scaffold.summary_op) + predictions = sess.run(spec.predictions) + self.assertAllClose( + expected_logits1, + predictions[('head1', prediction_keys.PredictionKeys.PREDICTIONS)]) + self.assertAllClose( + expected_logits2, + predictions[('head2', prediction_keys.PredictionKeys.PREDICTIONS)]) + + self.assertAllClose( + expected_logits1, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].value)) + self.assertAllClose( + expected_logits1, + sess.run(spec.export_outputs['head1'].value)) + self.assertAllClose( + expected_logits2, + sess.run(spec.export_outputs['head2'].value)) + def test_eval_two_heads_with_weights(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') head2 = head_lib.multi_label_head(n_classes=3, name='head2') @@ -178,7 +284,7 @@ class MultiHeadTest(test.TestCase): # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum ober batch and heads. + # Average over classes, weighted sum over batch and heads. expected_loss_head1 = 17.5 expected_loss_head2 = 30.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 @@ -191,6 +297,8 @@ class MultiHeadTest(test.TestCase): keys = metric_keys.MetricKeys expected_metrics = { + keys.LOSS + '/head1': expected_loss_head1, + keys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, @@ -231,18 +339,25 @@ class MultiHeadTest(test.TestCase): logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} - with self.assertRaisesRegexp( - NotImplementedError, - r'create_loss not yet implemented for MultiHead\.'): - multi_head.create_loss( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels) + loss = multi_head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels)[0] + tol = 1e-3 + with self.test_session(): + # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2] + # (averaged over classes, sum-reduced over examples). + self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol) def test_train_create_loss_two_heads_with_weights(self): - head1 = head_lib.multi_label_head(n_classes=2, name='head1') - head2 = head_lib.multi_label_head(n_classes=3, name='head2') + # Use different example weighting for each head weighting. + weights1 = np.array([[1.], [2.]], dtype=np.float32) + weights2 = np.array([[2.], [3.]]) + head1 = head_lib.multi_label_head(n_classes=2, name='head1', + weight_column='weights1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2', + weight_column='weights2') multi_head = multi_head_lib.multi_head( [head1, head2], head_weights=[1., 2.]) @@ -255,14 +370,105 @@ class MultiHeadTest(test.TestCase): 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), } - with self.assertRaisesRegexp( - NotImplementedError, - r'create_loss not yet implemented for MultiHead\.'): - multi_head.create_loss( - features={'x': np.array(((42,),), dtype=np.int32)}, - mode=model_fn.ModeKeys.TRAIN, - logits=logits, - labels=labels) + weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'weights1': weights1, + 'weights2': weights2 + }, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + tol = 1e-3 + with self.test_session(): + # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] + # = [10, 7.5] + # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] + # = [20, 10] + # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted merge = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) + # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 + self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + + def test_train_create_loss_logits_tensor(self): + """Tests create_loss with logits Tensor.""" + weights1 = np.array([[1.], [2.]], dtype=np.float32) + weights2 = np.array([[2.], [3.]]) + head1 = head_lib.multi_label_head(n_classes=2, name='head1', + weight_column='weights1') + head2 = head_lib.multi_label_head(n_classes=3, name='head2', + weight_column='weights2') + multi_head = multi_head_lib.multi_head( + [head1, head2], head_weights=[1., 2.]) + + logits = np.array([[-10., 10., 20., -20., 20.], + [-15., 10., -30., 20., -20.]], dtype=np.float32) + labels = { + 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), + 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), + } + weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'weights1': weights1, + 'weights2': weights2 + }, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + tol = 1e-3 + with self.test_session(): + # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] + # = [10, 7.5] + # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] + # = [20, 10] + # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted merge = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) + # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 + self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + + def test_train_create_loss_logits_tensor_multi_dim(self): + """Tests create_loss with multi-dimensional logits of shape [2, 2, 5].""" + head1 = head_lib.regression_head(label_dimension=2, name='head1') + head2 = head_lib.regression_head(label_dimension=3, name='head2') + multi_head = multi_head_lib.multi_head([head1, head2]) + + logits = np.array( + [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]], + [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]], + dtype=np.float32) + labels = { + 'head1': np.array([[[1., 0.], [1., 0.]], + [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32), + 'head2': np.array([[[0., 1., 0.], [0., 1., 0.]], + [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32), + } + # Loss for the first head: + # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + + # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2 + # = 28 + # Loss for the second head: + # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 + # = 74 + expected_weighted_sum_loss = 28. + 74. + + weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + features={}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + tol = 1e-3 + with self.test_session(): + self.assertAllClose( + expected_weighted_sum_loss, weighted_sum_loss.eval(), + rtol=tol, atol=tol) + self.assertAllClose( + 2. * 2. * 5., example_weight_sum.eval(), rtol=tol, atol=tol) def test_train_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') @@ -307,6 +513,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) _assert_simple_summaries(self, { + metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss, # Average loss over examples. metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, @@ -332,7 +539,7 @@ class MultiHeadTest(test.TestCase): # (1 - labels) * (logits > 0) * logits => # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] - # Average over classes, weighted sum ober batch and heads. + # Average over classes, weighted sum over batch and heads. expected_loss_head1 = 17.5 expected_loss_head2 = 30.0 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 @@ -367,6 +574,7 @@ class MultiHeadTest(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) _assert_simple_summaries(self, { + metric_keys.MetricKeys.LOSS: expected_loss, metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, # Average loss over examples. diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3a2394ee227f2ab78e6d4d3d882f2b10954699 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -0,0 +1,529 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to replicate model_fn's over local GPUs. + +This file contains util that allow to replicate `Estimator.model_fn` over +GPUs. Replicated version of a `model_fn` is returned that can subsequently +be used with `Estimator`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import six + +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.client import device_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util +from tensorflow.python.estimator.export import export_output as export_output_lib +from tensorflow.python.framework import device as framework_device +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import device_setter as device_setter_lib +from tensorflow.python.training import training_util + + +def replicate_model_fn(model_fn, optimizer_fn, devices=None): + """Replicate `Estimator.model_fn` over GPUs within a single host. + + The given `model_fn` specifies a single forward pass of a model. To replicate + such a model over GPUs, each GPU gets its own instance of the forward pass + (a.k.a. a tower). The input features and labels get sharded into the chunks + that correspond to the number of GPUs. Each tower computes its own loss based + on its input. For each such loss, gradients are computed. After that, the + available losses are summed to form aggregated loss. The available + gradients are summed too. Then, they update weights using the specified + optimizer. + + If `devices` are `None`, then all available GPUs are going to be used for + replication. If no GPUs are available, then the model is going to be + placed on the CPU. + + Two modes of local replication over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + + Here is an example of how one might use their `model_fn` to run over GPUs: + ```python + def optimizer_fn(): + return tf.train.GradientDescentOptimizer(learning_rate=0.001) + ... + def model_fn(...): # See `model_fn` in `Estimator`. + loss = ... + if mode == tf.estimator.ModeKeys.TRAIN: + # See the section below on `EstimatorSpec.train_op`. + return EstimatorSpec(mode=mode, loss=loss, train_op=tf.noop()) + + # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`. + return EstimatorSpec(...) + ... + classifier = tf.estimator.Estimator( + model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn)) + ``` + + On `EstimatorSpec.train_op`: + `model_fn` returns `EstimatorSpec.train_op` for + `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer. + `replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there + is no need to use an optimizer inside the user's `model_fn`. The + `EstimatorSpec.loss` subgraph is going to be executed, while + `EstimatorSpec.train_op` isn't going to be executed. One could pass + `train_op=tf.noop()` to `EstimatorSpec`. + + On sharding input features and labels: + Input features and labels are split for consumption by each tower. They are + split across the dimension 0. Features and labels need to be batch major. + + On reduction algorithms: + Certain algorithms were chosen for aggregating results of computations on + multiple towers: + - Losses from all towers are reduced using sum. + - Gradients are reduced using sum for each trainable variable. + - `eval_metrics_ops` are reduced per metric using `reduce_mean`. + - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are + reduced using concatenation. + - For all other fields of `EstimatorSpec` the values of the first tower + are taken. + + On distribution of variables: + Variables are not duplicated between towers. Instead, they are placed on a + single device as defined above and shared across towers. + + Other current limitations: + - `predictions` are not supported for `ModeKeys.EVAL`. That is required for + `tf.contrib.estimator.add_metrics`. + + Args: + model_fn: `model_fn` as defined in `Estimator`. See the section above about + the train_op argument of `EstimatorSpec`. + optimizer_fn: a function that returns an optimizer instance. The function + may accept one `params` argument. This is the `params` argument as + defined by `Estimator`. See the `Estimator` documentation for details. + devices: Optional list of devices to replicate the model across. This + argument can be used to replice only on the subset of available GPUs. + If `None`, then all available GPUs are going to be used for replication. + If no GPUs are available, then the model is going to be placed on the CPU. + + Returns: + A replicated version of the supplied `model_fn`. Returned function that + conforms to the requirements of `Estimator`'s `model_fn` and can be used + instead of the supplied `model_fn`. + """ + return _replicate_model_fn_with_mode( + model_fn, + optimizer_fn, + devices, + # TODO(isaprykin): Query system configuration to choose modes other than + # `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often appropriate. + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER) + + +class _VariableDistributionMode(object): + """Modes for variable distribution used for forcing a particular one. + + Forcing a mode is meant for performance experimentation purposes rather than + for general use cases. + """ + + SHARED_LOCAL_PARAMETER_SERVER = 1 + """Variables are placed on a single device and shared across all devices. + + Two ways to achieve this distribution over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + """ + + SHARED_ROUND_ROBIN = 2 + """Variables are placed on all devices in a round-robin fashion. + + Every subsequent variable is placed on the next device. There is only one + copy of each variable that is shared across all devices. + """ + + +def _replicate_model_fn_with_mode( + model_fn, + optimizer_fn, + devices=None, + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER): + """A version of `replicate_model_fn` that allows to specify a `mode`.""" + if not devices: + devices = _get_local_devices('GPU') or _get_local_devices('CPU') + + is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] + consolidation_device = '/{}:0'.format('GPU' + if is_a_single_gpu_case else 'CPU') + + ps_devices = [consolidation_device] + if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN: + ps_devices = devices + + tf_logging.info('Replicating the `model_fn` across {}. Variables are going ' + 'to be placed on {}. Consolidation device is going to be {}.' + .format(devices, ps_devices, consolidation_device)) + + def replicated_model_fn(features, labels, mode, params=None, config=None): + """Replicated version of `model_fn` to be used instead.""" + feature_shards, label_shards = _split_batch( + features, labels, len(devices), device=consolidation_device) + tower_specs = _get_loss_towers( + model_fn=model_fn, + mode=mode, + features=feature_shards, + labels=label_shards, + params=params, + config=config, + devices=devices, + local_ps_devices=ps_devices) + + if mode == model_fn_lib.ModeKeys.TRAIN: + train_op = _minimize_towers(tower_specs, + _call_optimizer_fn(optimizer_fn, params)) + return _train_spec( + tower_specs, train_op, aggregation_device=consolidation_device) + elif mode == model_fn_lib.ModeKeys.EVAL: + return _eval_spec(tower_specs, aggregation_device=consolidation_device) + elif mode == model_fn_lib.ModeKeys.PREDICT: + return _predict_spec(tower_specs, aggregation_device=consolidation_device) + + return replicated_model_fn + + +def _get_local_devices(device_type): + local_device_protos = device_lib.list_local_devices() + return [ + device.name + for device in local_device_protos + if device.device_type == device_type + ] + + +def _split_batch(features, labels, number_of_shards, device): + """Split input features and labes into batches.""" + + def split_dictionary(dictionary): + """Split a dictionary into shards.""" + shards = [{} for _ in range(number_of_shards)] + for name, tensor in six.iteritems(dictionary): + if isinstance(tensor, sparse_tensor.SparseTensor): + for i, shard in enumerate( + sparse_ops.sparse_split( + sp_input=tensor, num_split=number_of_shards, axis=0)): + shards[i][name] = shard + else: + for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): + shards[i][name] = shard + return shards + + with ops_lib.name_scope('split_inputs'): + with ops_lib.device(device): + if isinstance(features, dict): + feature_shards = split_dictionary(features) + else: + feature_shards = array_ops.split(features, number_of_shards) + + if labels is None: + label_shards = None + elif isinstance(labels, dict): + label_shards = split_dictionary(labels) + else: + label_shards = array_ops.split(labels, number_of_shards) + return feature_shards, label_shards + + +_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}' + + +def _get_loss_towers(model_fn, + mode, + features, + labels, + params, + config, + devices, + local_ps_devices, + name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): + """Replicate the loss computation across devices.""" + tower_specs = [] + + model_fn_args = util.fn_args(model_fn) + optional_params = {} + if 'params' in model_fn_args: + optional_params['params'] = copy.deepcopy(params) + if 'config' in model_fn_args: + optional_params['config'] = copy.deepcopy(config) + + # pylint: disable=protected-access + round_robin_strategy = device_setter_lib._RoundRobinStrategy( + num_tasks=len(local_ps_devices)) + # pylint: enable=protected-access + + for i, device in enumerate(devices): + is_the_first_tower = (i == 0) + + device_setter = _local_device_setter( + worker_device=device, + ps_devices=local_ps_devices, + ps_strategy=round_robin_strategy) + + # We would like to preserve the names of the variables and ops that the user + # might be relying on. Names without a prefix are going to resolve to + # variables and ops of the first tower. + name_scope = name_scope_pattern + if is_the_first_tower: + name_scope = '' + + with variable_scope.variable_scope('', reuse=not is_the_first_tower): + with ops_lib.name_scope(name_scope.format(i)): + with ops_lib.device(device_setter): + labels_shard = None + if labels: + labels_shard = labels[i] + + tower_specs.append( + model_fn( + mode=mode, + features=features[i], + labels=labels_shard, + **optional_params)) + return tower_specs + + +def _local_device_setter(worker_device, ps_devices, ps_strategy): + """A device setter that puts distributes Var/Ops to PS/workers.""" + ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] + + def local_device_chooser(op): + current_device = framework_device.DeviceSpec.from_string(op.device or '') + + node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def + if node_def.op in ps_ops: + ps_device_spec = framework_device.DeviceSpec.from_string( + '{}'.format(ps_devices[ps_strategy(op)])) + + ps_device_spec.merge_from(current_device) + return ps_device_spec.to_string() + else: + worker_device_spec = framework_device.DeviceSpec.from_string( + worker_device or '') + worker_device_spec.merge_from(current_device) + return worker_device_spec.to_string() + + return local_device_chooser + + +def _minimize_towers(tower_specs, optimizer): + """Aggregate and apply gradients for computed losses.""" + grad_lists = {} + for tower_spec in tower_specs: + with ops_lib.device(tower_spec.loss.device): + for grad, var in optimizer.compute_gradients(tower_spec.loss): + if grad is not None: + grad_lists.setdefault(var, []).append(grad) + + aggregated_grads = [] + with ops_lib.name_scope('gradient_aggregating'): + for var, grads in six.iteritems(grad_lists): + grad = _compute_sum_on_device(grads, var.device) + aggregated_grads.append((grad, var)) + + train_op = optimizer.apply_gradients( + aggregated_grads, global_step=training_util.get_global_step()) + + return train_op + + +def _call_optimizer_fn(optimizer_fn, params): + arguments = {} + optimizer_fn_arguments = util.fn_args(optimizer_fn) + if 'params' in optimizer_fn_arguments: + arguments['params'] = params + return optimizer_fn(**arguments) + + +def _compute_sum_on_device(values, device, name=None): + with ops_lib.device(device): + if isinstance(values[0], ops_lib.IndexedSlices): + if name: + raise ValueError('The name {} is not expected to be given to ' + 'IndexedSlices {}'.format(name, values)) + + values_concat = array_ops.concat([v.values for v in values], axis=0) + indices_concat = array_ops.concat([v.indices for v in values], axis=0) + return ops_lib.IndexedSlices(values_concat, indices_concat, + values[0].dense_shape) + else: + return math_ops.add_n(values, name=name) + + +def _train_spec(tower_specs, + train_op, + aggregation_device, + aggregated_loss_name='loss'): + """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN + estimator_spec['train_op'] = train_op + estimator_spec['loss'] = _compute_sum_on_device( + [spec.loss for spec in tower_specs], aggregation_device, + aggregated_loss_name) + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): + """Populate replicated EstimatorSpec for `GraphKeys.EVAL`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL + estimator_spec['loss'] = _compute_sum_on_device( + [spec.loss for spec in tower_specs], aggregation_device, + aggregated_loss_name) + + update_ops = [] + for tower_spec in tower_specs: + for name, (_, update_op) in six.iteritems(tower_spec.eval_metric_ops): + update_ops.append(update_op) + + with ops_lib.control_dependencies(update_ops): + reduced_update_op = _reduce_metric_variables(len(tower_specs)) + + eval_metric_ops = {} + for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops): + eval_metric_ops[name] = (metric_tensor, reduced_update_op) + estimator_spec['eval_metric_ops'] = eval_metric_ops + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _reduce_metric_variables(number_of_towers): + """Aggregate local variables used in metrics into the first tower.""" + if number_of_towers == 1: + return control_flow_ops.no_op() + + metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES) + variables_per_tower = len(metric_variables) // number_of_towers + + if len(metric_variables) % number_of_towers != 0: + raise ValueError( + 'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.' + ' Expected {} local variables, but got {} instead.'.format( + variables_per_tower * number_of_towers, len(metric_variables))) + + # `metric_variables` has the size of `variables_per_tower` x + # number_of_towers. Each tower is produced by calling the same model_fn. + # First `variables_per_tower` correspond to the first tower. Each such + # variable has an replica at the `(variables_per_tower * i)` position, where + # `i` is `[1.. number_of_towers]`. We are going to add values from replicas + # to each variable of the first tower. We then zero out replica values, so + # that `_reduce_metric_variables` operation is idempotent. If a metric + # is then computed based on local variables from the first tower, then the + # resulting metric is an estimate for all `number_of_towers` towers. + ops = [] + for i in range(0, variables_per_tower): + next_replica_id = i + variables_per_tower + replicas = [ + metric_variables[replica_id] + for replica_id in range(next_replica_id, len(metric_variables), + variables_per_tower) + ] # `replicas` doesn't contain the first-tower variable. + + reduce_op = state_ops.assign_add(metric_variables[i], + math_ops.add_n(replicas)) + + with ops_lib.control_dependencies([reduce_op]): + for replica in replicas: + zeros_for_replica = array_ops.zeros( + array_ops.shape(replica), dtype=replica.dtype) + zero_out_replica_op = state_ops.assign(replica, zeros_for_replica) + ops.append(zero_out_replica_op) + + return control_flow_ops.group(*ops) + + +def _predict_spec(tower_specs, aggregation_device): + """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT + + with ops_lib.device(aggregation_device): + estimator_spec['predictions'] = _concat_tensor_dicts( + *[tower_spec.predictions for tower_spec in tower_specs]) + + export_outputs_dict = _dict_concat( + *[tower_spec.export_outputs for tower_spec in tower_specs]) + + export_outputs = {} + for name, export_output_list in six.iteritems(export_outputs_dict): + if isinstance(export_output_list[0], export_output_lib.PredictOutput): + export_outputs[name] = export_output_lib.PredictOutput( + outputs=_concat_tensor_dicts(*[ + export_output.outputs for export_output in export_output_list + ])) + elif isinstance(export_output_list[0], + export_output_lib.RegressionOutput): + export_outputs[name] = export_output_lib.RegressionOutput( + value=array_ops.concat( + [export_output.value for export_output in export_output_list], + axis=0)) + elif isinstance(export_output_list[0], + export_output_lib.ClassificationOutput): + scores = None + if export_output_list[0].scores is not None: + scores = array_ops.concat( + [export_output.scores for export_output in export_output_list], + axis=0) + + classes = None + if export_output_list[0].classes is not None: + classes = array_ops.stack( + [export_output.classes for export_output in export_output_list], + axis=0) + + export_outputs[name] = export_output_lib.ClassificationOutput( + scores=scores, classes=classes) + + estimator_spec['export_outputs'] = export_outputs + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _concat_tensor_dicts(*tensor_dicts): + return { + name: array_ops.concat(tensors, axis=0, name=name) + for name, tensors in six.iteritems(_dict_concat(*tensor_dicts)) + } + + +def _dict_concat(*dicts): + list_dict = {} + for d in dicts: + if d is None: + continue + + for k, v in six.iteritems(d): + list_dict.setdefault(k, []).append(v) + return list_dict diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a83a1b84079f115f94be33297f0ab0e2e8f2f7e3 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -0,0 +1,1087 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utilities that replicate `Estimator.model_fn` over GPUs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import shutil +import tempfile +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import replicate_model_fn +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import dnn +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import device_setter +from tensorflow.python.training import gradient_descent + + +# TODO(isaprykin): Parametrize all the tests on +# replicate_model_fn._VariableDistributionMode when it's supported. +class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def test_complete_flow_with_public_version(self): + return self._complete_flow_with_mode(mode=None) + + def test_complete_flow_with_mode_local_ps_server(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER) + + def test_complete_flow_with_mode_round_robin(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) + + def _complete_flow_with_mode(self, mode): + n_classes = 3 + input_dimension = 2 + batch_size = 12 + + data = np.linspace( + 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) + x_data = data.reshape(batch_size, input_dimension) + categorical_data = np.random.random_integers( + 0, len(x_data), size=len(x_data)) + y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data, + 'categories': categorical_data}, + y=y_data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data, + 'categories': categorical_data}, + y=y_data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data, + 'categories': categorical_data}, + batch_size=batch_size, + shuffle=False) + + feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)), + feature_column.embedding_column( + feature_column.categorical_column_with_vocabulary_list( + 'categories', + vocabulary_list=np.linspace( + 0., len(x_data), len(x_data), dtype=np.int64)), 1) + ] + + estimator = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=feature_columns, + n_classes=n_classes, + model_dir=self._model_dir) + + def optimizer_fn(): + return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + + if not mode: # Use the public `replicate_model_fn`. + model_fn = replicate_model_fn.replicate_model_fn( + estimator.model_fn, + optimizer_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2']) + else: + model_fn = replicate_model_fn._replicate_model_fn_with_mode( + estimator.model_fn, + optimizer_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2'], + mode=mode) + + estimator = estimator_lib.Estimator( + model_fn=model_fn, + model_dir=estimator.model_dir, + config=estimator.config, + params=estimator.params) + + num_steps = 10 + estimator.train(train_input_fn, steps=num_steps) + + scores = estimator.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PROBABILITIES] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def _as_label(self, data_in_float): + return np.rint(data_in_float).astype(np.int64) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +class ReplicateModelTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = None + if mode is not model_fn_lib.ModeKeys.PREDICT: + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=control_flow_ops.no_op()) # This train_op isn't actually used. + + def optimizer_fn(self, params): + return gradient_descent.GradientDescentOptimizer(params['learning_rate']) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_train_spec_with_optimizer_without_params(self): + + def optimizer_fn_without_params(): + return gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: # pylint: disable=unused-variable + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + optimizer_fn_without_params, + devices=['/gpu:0', '/gpu:1']) + # This call is going to fail if `replicated_model_fn` is still passing + # `params` inside `optimizer_fn`, even though the latter doesn't take any: + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + del estimator_spec + + def test_eval(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # loss[i] = features[i] * 10 - labels[i]. + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_eval_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + +class GetLossTowersTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + + return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss)) + + def test_gradients_are_computed(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=None, + features=[[0.6], [1.6]], + labels=[[0.6], [0.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_devices=['/gpu:0'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 2) + + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('Sum:0', tower_specs[0].loss.name) + self.assertEqual(1.0, session.run(tower_specs[0].loss)) + + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name) + # The input batch for the second tower had a loss that is 1.0 + # bigger: 0.6 vs 1.6. + self.assertEqual(2.0, session.run(tower_specs[1].loss)) + + self.assertEqual(1, len(variables.global_variables())) + self.assertEqual(1, len(variables.trainable_variables())) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(0.25, session.run(c)) + + def test_variables_are_round_robined_correctly(self): + """Test that creates multiple variables and tests round-robin placement.""" + + def model_fn(mode, features, labels, params): + del params + for variable_name in ['a', 'b', 'c', 'd']: + c = variable_scope.get_variable( + variable_name, + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=math_ops.reduce_sum(loss)) + + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + model_fn, + mode=None, + features=[[0.6], [1.6], [2.6]], + labels=[[0.6], [0.6], [2.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1', '/gpu:3'], + local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 3) + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('/device:GPU:3', tower_specs[2].loss.device) + + with variable_scope.variable_scope('', reuse=True): + a = variable_scope.get_variable('a', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', a.device) + b = variable_scope.get_variable('b', dtype=dtypes.float64) + self.assertEqual('/device:GPU:1', b.device) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:3', c.device) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', d.device) + + +class SplitBatchTest(test_util.TensorFlowTestCase): + + def evaluate_shards(self, first_list, second_list): + evaluate_items = lambda x: x.eval() + return list(map(evaluate_items, first_list)), list( + map(evaluate_items, second_list)) + + def test_simple_half_split(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards) + + def test_to_each_their_own(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 4, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards) + self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards) + + def test_one_batch(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards) + + def test_half_split_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0, 11.0], label_shards[0].eval()) + self.assertAllEqual([12.0, 13.0], label_shards[1].eval()) + + def test_one_batch_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0, 2.0, 3.0], + feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0, 6.0, 7.0], + feature_shards[0]['second'].eval()) + self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval()) + + def test_feature_and_label_dictionaries(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]} + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0], label_shards[0]['first'].eval()) + self.assertAllEqual([12.0], label_shards[0]['second'].eval()) + self.assertAllEqual([11], label_shards[1]['first'].eval()) + self.assertAllEqual([13.0], label_shards[1]['second'].eval()) + + +class TrainSpecTest(test_util.TensorFlowTestCase): + + expected_predictions = {} + + def create_estimator_spec(self, loss): + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.TRAIN, + loss=loss, + train_op=loss, # Not used; currently required. + predictions=self.expected_predictions) + + def create_constant_loss(self, loss_value): + return constant_op.constant(loss_value, dtype=dtypes.float64) + + def test_example(self): + with self.test_session() as session: + tower_losses = list(map(self.create_constant_loss, [2, 4, 6])) + tower_specs = list(map(self.create_estimator_spec, tower_losses)) + + expected_train_op = tower_losses[1] + + estimator_spec = replicate_model_fn._train_spec( + tower_specs, expected_train_op, aggregation_device='/gpu:0') + + self.assertEqual(expected_train_op, estimator_spec.train_op) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + self.assertEqual(self.expected_predictions, estimator_spec.predictions) + + +class EvalSpecTest(test_util.TensorFlowTestCase): + + def create_estimator_spec(self, loss, metrics): + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics) + + def create_constant_loss(self, loss_value): + return constant_op.constant(loss_value, dtype=dtypes.float64) + + def create_eval_metrics(self, noise): + predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise]) + labels = np.array([0.1, 0.2, 0.3, 0.6]) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + return metrics + + def test_example(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [2, 4, 6]) + tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy, auc = session.run([accuracy, auc]) + + self.assertNear((12 - 2) / 12, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + + def test_handles_single_tower(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [5]) + tower_metrics = map(self.create_eval_metrics, [0.2]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + self.assertNear((4 - 1) / 4, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(5, session.run(estimator_spec.loss)) + + +class PredictSpecTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([features[0], features[0]]), c) + + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.PREDICT, + predictions={ + 'probabilities': predictions + }) + + def test_example(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=None, + features=[[0.1], [0.2]], + labels=[[], []], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_devices=['/gpu:0'], + ) + session.run(variables.global_variables_initializer()) + + estimator_spec = replicate_model_fn._predict_spec( + tower_specs, aggregation_device='/gpu:0') + + self.assertEqual('/device:GPU:0', + estimator_spec.predictions['probabilities'].device) + self.assertAllClose({ + 'probabilities': np.array([0.35, 0.35, 0.45, 0.45]) + }, session.run(estimator_spec.predictions)) + + +class ReduceMetricVariablesTest(test_util.TensorFlowTestCase): + + def create_metric_variable(self, initial_value, name): + return variable_scope.variable( + initial_value, + trainable=False, + collections=[ops_lib.GraphKeys.METRIC_VARIABLES], + validate_shape=True, + name=name) + + def create_tower_metrics(self, tower_id): + with variable_scope.variable_scope('', reuse=(tower_id != 0)): + self.create_metric_variable(1.3 * (tower_id + 1), 'total') + self.create_metric_variable(2.3 * (tower_id + 1), 'count') + self.create_metric_variable( + np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total') + + def test_example(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + # 1st tower = 1.3, 2.3, [3.3, 3.5, 3.7] + # 2nd tower = 2.6, 4.6, [6.6, 7.0, 7.4] + # 3rd tower = 3.9, 6.9, [9.9, 10.5, 11.1] + # Reduced = 7.8, 13.8, [19.8, 21.0, 22.2] + # Towers are accumulated in the first tower. + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(7.8, local_metrics[0], 0.01) + self.assertNear(13.8, local_metrics[1], 0.01) + self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01) + self.assertNear(0.0, local_metrics[3], 0.01) + self.assertNear(0.0, local_metrics[4], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01) + self.assertNear(0.0, local_metrics[6], 0.01) + self.assertNear(0.0, local_metrics[7], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01) + + def test_reduce_is_idempotent(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + for _ in range(20): + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(7.8, local_metrics[0], 0.01) + self.assertNear(13.8, local_metrics[1], 0.01) + self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01) + self.assertNear(0.0, local_metrics[3], 0.01) + self.assertNear(0.0, local_metrics[4], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01) + self.assertNear(0.0, local_metrics[6], 0.01) + self.assertNear(0.0, local_metrics[7], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01) + + def test_handles_single_tower(self): + with self.test_session() as session: + self.create_tower_metrics(0) + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=1)) + + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(1.3, local_metrics[0], 0.01) + self.assertNear(2.3, local_metrics[1], 0.01) + self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01) + + def test_doesnt_accept_uneven_number_of_variables(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + self.create_metric_variable(-1.0, 'oddball') + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + with self.assertRaisesRegexp(ValueError, ''): + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + +class MergeExportOutputsTest(test_util.TensorFlowTestCase): + + def optimizer_fn(self): + return gradient_descent.GradientDescentOptimizer(1.0) + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = {'probabilities': math_ops.multiply(features, c)} + loss = losses.absolute_difference( + labels=labels, + predictions=predictions['probabilities'], + reduction=losses.Reduction.SUM) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']), + 'auc': metrics_lib.auc(labels, predictions['probabilities']) + } + tensor_string_repr = str(features) + classes = constant_op.constant( + re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1), + dtype=dtypes.string) + + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(predictions), + 'classification_output': + export_output.ClassificationOutput(predictions['probabilities'], + classes), + 'classification_scores': + export_output.ClassificationOutput( + scores=predictions['probabilities']), + 'classification_classes': + export_output.ClassificationOutput(classes=classes), + 'regression_output': + export_output.RegressionOutput(predictions['probabilities']), + } + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=math_ops.reduce_sum(loss), + eval_metric_ops=metrics, + predictions=predictions, + train_op=loss, # This train_op isn't actually used. + export_outputs=export_outputs) + + def replicate_estimator_spec(self, session): + features = np.array([0.01, 0.002]) + labels = np.array([0.01, 0.02]) + + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.PREDICT, {}) + session.run(variables.global_variables_initializer()) + return estimator_spec + + def test_merde_predict_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + { + 'probabilities': np.array([0.1, 0.02]) + }, + session.run(estimator_spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs)) + + def test_merge_classification_output_scores_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_output'].scores)) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_output'].classes)) + + def test_merge_classification_output_scores(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_scores'].scores)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_scores'].classes) + + def test_merge_classification_output_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_classes'].classes)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_classes'].scores) + + def test_merge_regression_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run(estimator_spec.export_outputs['regression_output'].value)) + + +class GetLocalDevicesTest(test_util.TensorFlowTestCase): + + def test_there_is_at_least_a_cpu(self): + self.assertTrue(replicate_model_fn._get_local_devices('CPU')) + + def test_there_is_no_xpu(self): + self.assertFalse( + replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist. + + def test_whether_there_is_a_gpu(self): + if test.is_gpu_available(): + self.assertTrue(len(replicate_model_fn._get_local_devices('GPU'))) + + +class LocalDeviceSetterTest(test_util.TensorFlowTestCase): + + def test_vars_are_on_ps_but_ops_are_on_workers(self): + ps_devices = ['/device:GPU:3'] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + def test_round_robin_placement(self): + ps_devices = [ + '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4' + ] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:0', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:1', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:4', c.device) + + d = variables.Variable(0.03) + self.assertEqual('/device:GPU:0', d.device) + + c_op = array_ops.concat(c, axis=0) + self.assertEqual('/device:GPU:2', c_op.device) + + +class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): + + def test_vectors(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', total.op.name) + self.assertEqual(10.0, session.run(total)) + + def test_tensors(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [[1.0, 2.0], [3.0, 4.0]], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', total.op.name) + self.assertAllEqual([4.0, 6.0], session.run(total)) + + def test_indexedslices(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 1], + dense_shape=constant_op.constant([2])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([4.0, 6.0], + session.run(ops_lib.convert_to_tensor(total))) + + def test_indexedslices_higher_dimensions(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([[1.0, 5.0], [2.0, 6.0]]), [0, 1], + dense_shape=constant_op.constant([2, 4])) + b = ops_lib.IndexedSlices( + constant_op.constant([[3.0, 7.0], [4.0, 8.0]]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([[4.0, 12.0], [6.0, 14.0]], + session.run(ops_lib.convert_to_tensor(total))) + + def test_indexedslices_some_dont_overlap(self): + with self.test_session() as session: + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 3], + dense_shape=constant_op.constant([4])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + total = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0') + + self.assertEqual('/device:GPU:0', total.device) + self.assertAllEqual([4.0, 4.0, 0.0, 2.0], + session.run(ops_lib.convert_to_tensor(total))) + + def test_no_name_for_indexslices(self): + a = ops_lib.IndexedSlices( + constant_op.constant([1.0, 2.0]), [0, 1], + dense_shape=constant_op.constant([2])) + b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) + + with self.assertRaisesRegexp(ValueError, ''): + _ = replicate_model_fn._compute_sum_on_device( + [a, b], device='/device:GPU:0', name='cant_name_indexslices') + + +class ConcatTensorDictsTest(test_util.TensorFlowTestCase): + + def test_example(self): + tensor_dicts = [ + { + 'a': np.array([1.0, 2.0]), + 'b': np.array([11.0]), + 'c': np.array([21.0]), + }, + { + 'a': np.array([3.0]), + 'b': np.array([12.0, 13.0]), + }, + { + 'b': np.array([14.0]), + }, + ] + + with self.test_session() as session: + self.assertAllClose({ + 'a': np.array([1.0, 2.0, 3.0]), + 'b': np.array([11.0, 12.0, 13.0, 14.0]), + 'c': np.array([21.0]), + }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index 0d67e09f8151b48c97094b6b48f26e63443707ef..f72280c4ecf19e33278ffe74061f44bbb7b21709 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -24,7 +24,7 @@ import numpy as np from tensorflow.contrib import framework from tensorflow.contrib.factorization.python.ops import gmm_ops from tensorflow.contrib.framework.python.framework import checkpoint_utils -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.python.framework import constant_op @@ -167,7 +167,7 @@ class GMM(estimator.Estimator): self._num_clusters, self._random_seed, self._covariance_type, self._params) - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses) training_op = with_dependencies([training_op, incr_step], loss) training_hooks = [_InitializeClustersHook( diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3976395d78e9188dd56d5b3b32fa8a3daf43c37d..4fe22ea26ec5f5a43f1c99d1fee518b1d326c5c9 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.factorization.python.ops import factorization_ops -from tensorflow.contrib.framework.python.ops import variables as framework_variables from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.python.framework import dtypes @@ -32,175 +31,81 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, - input_row_indices, input_col_indices, row_prep_ops, - col_prep_ops, init_op, completed_sweeps_var): + def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, + switch_op): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_ops: A list of ops. The ops created by this hook will have - control dependencies on `train_ops`. - num_rows: int, the total number of rows to be processed. - num_cols: int, the total number of columns to be processed. - input_row_indices: A Tensor of type int64. The indices of the input rows - that are processed during the current sweep. All elements of - `input_row_indices` must be in [0, num_rows). - input_col_indices: A Tensor of type int64. The indices of the input - columns that are processed during the current sweep. All elements of - `input_col_indices` must be in [0, num_cols). - row_prep_ops: list of ops, to be run before the beginning of each row - sweep, in the given order. - col_prep_ops: list of ops, to be run before the beginning of each column - sweep, in the given order. + is_sweep_done_var: A Boolean tf.Variable, determines whether we are + starting a new sweep (this is used to determine when to run the prep ops + below). init_op: op to be run once before training. This is typically a local initialization op (such as cache initialization). - completed_sweeps_var: An integer tf.Variable, indicates the number of - completed sweeps. It is updated by the hook. + row_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each row sweep (and during initialization), in the given order. + col_prep_ops: A list of TensorFlow ops, to be run before the beginning of + each column sweep (and during initialization), in the given order. + row_train_op: A TensorFlow op to be run during row sweeps. + col_train_op: A TensorFlow op to be run during column sweeps. + switch_op: A TensorFlow op to be run before each sweep. """ - self._num_rows = num_rows - self._num_cols = num_cols + self._is_row_sweep_var = is_row_sweep_var + self._is_sweep_done_var = is_sweep_done_var + self._init_op = init_op self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._init_op = init_op - self._is_row_sweep_var = is_row_sweep_var - self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the init_ops have been run. + self._row_train_op = row_train_op + self._col_train_op = col_train_op + self._switch_op = switch_op + # Boolean variable that determines whether the init_op has been run. self._is_initialized = False - # Ops to run jointly with train_ops, responsible for updating - # `is_row_sweep_var` and incrementing the `global_step` and - # `completed_sweeps` counters. - self._update_op, self._is_sweep_done_var, self._switch_op = ( - self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) - - def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): - """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - - Creates two boolean tensors `processed_rows` and `processed_cols`, which - keep track of which rows/cols have been processed during the current sweep. - Returns ops that should be run after each row / col update. - - When `self._is_row_sweep_var` is True, it sets - processed_rows[input_row_indices] to True. - - When `self._is_row_sweep_var` is False, it sets - processed_cols[input_col_indices] to True. - - Args: - input_row_indices: A Tensor. The indices of the input rows that are - processed during the current sweep. - input_col_indices: A Tensor. The indices of the input columns that - are processed during the current sweep. - train_ops: A list of ops. The ops created by this function have control - dependencies on `train_ops`. - - Returns: - A tuple consisting of: - update_op: An op to be run jointly with training. It updates the state - and increments counters (global step and completed sweeps). - is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is - done, i.e. all rows (during a row sweep) or all columns (during a - column sweep) have been processed. - switch_op: An op to be run in `self.before_run` when the sweep is done. - """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) - with ops.colocate_with(processed_rows_init): - processed_rows = variable_scope.variable( - processed_rows_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_rows") - processed_cols_init = array_ops.fill(dims=[self._num_cols], value=False) - with ops.colocate_with(processed_cols_init): - processed_cols = variable_scope.variable( - processed_cols_init, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="sweep_hook_processed_cols") - switch_ops = control_flow_ops.group( - state_ops.assign( - self._is_row_sweep_var, - math_ops.logical_not(self._is_row_sweep_var)), - state_ops.assign(processed_rows, processed_rows_init), - state_ops.assign(processed_cols, processed_cols_init)) - is_sweep_done_var = variable_scope.variable( - False, - collections=[ops.GraphKeys.GLOBAL_VARIABLES], - trainable=False, - name="is_sweep_done") - - # After running the `train_ops`, updates `processed_rows` or - # `processed_cols` tensors, depending on whether this is a row or col sweep. - with ops.control_dependencies(train_ops): - with ops.colocate_with(processed_rows): - update_processed_rows = state_ops.scatter_update( - processed_rows, - input_row_indices, - math_ops.logical_and( - self._is_row_sweep_var, - array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) - with ops.colocate_with(processed_cols): - update_processed_cols = state_ops.scatter_update( - processed_cols, - input_col_indices, - math_ops.logical_and( - math_ops.logical_not(self._is_row_sweep_var), - array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) - update_processed_op = control_flow_ops.group( - update_processed_rows, update_processed_cols) - - with ops.control_dependencies([update_processed_op]): - is_sweep_done = math_ops.logical_or( - math_ops.reduce_all(processed_rows), - math_ops.reduce_all(processed_cols)) - # Increments global step. - global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op() - # Increments completed sweeps. - completed_sweeps_incr_op = state_ops.assign_add( - self._completed_sweeps_var, - math_ops.cast(is_sweep_done, dtypes.int32), - use_locking=True).op - update_ops = control_flow_ops.group( - global_step_incr_op, - completed_sweeps_incr_op, - state_ops.assign(is_sweep_done_var, is_sweep_done)) - - return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Runs the appropriate init ops and prep ops. sess = run_context.session is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init op.") + logging.info("SweepHook running init op.") sess.run(self._init_op) if is_sweep_done: + logging.info("SweepHook starting the next sweep.") sess.run(self._switch_op) + is_row_sweep = sess.run(self._is_row_sweep_var) if is_sweep_done or not self._is_initialized: - logging.info("SweepHook running sweep prep ops.") - row_sweep = sess.run(self._is_row_sweep_var) - prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops + logging.info("SweepHook running prep ops for the {} sweep.".format( + "row" if is_row_sweep else "col")) + prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops for prep_op in prep_ops: sess.run(prep_op) - self._is_initialized = True - - # Requests running `self._update_op` jointly with the training op. logging.info("Next fit step starting.") - return session_run_hook.SessionRunArgs(fetches=[self._update_op]) + return session_run_hook.SessionRunArgs( + fetches=[self._row_train_op if is_row_sweep else self._col_train_op]) - def after_run(self, run_context, run_values): - logging.info("Fit step done.") + +class _IncrementGlobalStepHook(session_run_hook.SessionRunHook): + """Hook that increments the global step.""" + + def __init__(self): + global_step = training_util.get_global_step() + if global_step: + self._global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + self._global_step_incr_op = None + + def before_run(self, run_context): + if self._global_step_incr_op: + run_context.session.run(self._global_step_incr_op) class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -246,6 +151,9 @@ def _wals_factorization_model_function(features, labels, mode, params): Returns: A ModelFnOps object. + + Raises: + ValueError: If `mode` is not recognized. """ assert labels is None use_factors_weights_cache = (params["use_factors_weights_cache_for_training"] @@ -269,86 +177,145 @@ def _wals_factorization_model_function(features, labels, mode, params): use_gramian_cache=use_gramian_cache) # Get input rows and cols. We either update rows or columns depending on - # the value of row_sweep, which is maintained using a session hook + # the value of row_sweep, which is maintained using a session hook. input_rows = features[WALSMatrixFactorization.INPUT_ROWS] input_cols = features[WALSMatrixFactorization.INPUT_COLS] - input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0]) - input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0]) - - # Train ops, controlled using the SweepHook - # We need to run the following ops: - # Before a row sweep: - # row_update_prep_gramian_op - # initialize_row_update_op - # During a row sweep: - # update_row_factors_op - # Before a col sweep: - # col_update_prep_gramian_op - # initialize_col_update_op - # During a col sweep: - # update_col_factors_op - - is_row_sweep_var = variable_scope.variable( - True, - trainable=False, - name="is_row_sweep", - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - completed_sweeps_var = variable_scope.variable( - 0, - trainable=False, - name=WALSMatrixFactorization.COMPLETED_SWEEPS, - collections=[ops.GraphKeys.GLOBAL_VARIABLES]) - - # The row sweep is determined by is_row_sweep_var (controlled by the - # sweep_hook) in TRAIN mode, and manually in EVAL mode. - is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW] - if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var) - - def update_row_factors(): - return model.update_row_factors(sp_input=input_rows, transpose_input=False) - - def update_col_factors(): - return model.update_col_factors(sp_input=input_cols, transpose_input=True) - - (_, train_op, - unregularized_loss, regularization, sum_weights) = control_flow_ops.cond( - is_row_sweep, update_row_factors, update_col_factors) - loss = unregularized_loss + regularization - root_weighted_squared_error = math_ops.sqrt(unregularized_loss / sum_weights) - - row_prep_ops = [ - model.row_update_prep_gramian_op, model.initialize_row_update_op - ] - col_prep_ops = [ - model.col_update_prep_gramian_op, model.initialize_col_update_op - ] - init_ops = [model.worker_init] - - sweep_hook = _SweepHook( - is_row_sweep_var, - [train_op, loss], - params["num_rows"], - params["num_cols"], - input_row_indices, - input_col_indices, - row_prep_ops, - col_prep_ops, - init_ops, - completed_sweeps_var) - training_hooks = [sweep_hook] - if max_sweeps is not None: - training_hooks.append(_StopAtSweepHook(max_sweeps)) - - # The root weighted squared error = - # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) - summary.scalar("loss", loss) # the estimated total training loss - summary.scalar("root_weighted_squared_error", root_weighted_squared_error) - summary.scalar("completed_sweeps", completed_sweeps_var) - - # Prediction ops (only return predictions in INFER mode) - predictions = {} - if mode == model_fn.ModeKeys.INFER: - project_row = features[WALSMatrixFactorization.PROJECT_ROW] + + # TRAIN mode: + if mode == model_fn.ModeKeys.TRAIN: + # Training consists of the following ops (controlled using a SweepHook). + # Before a row sweep: + # row_update_prep_gramian_op + # initialize_row_update_op + # During a row sweep: + # update_row_factors_op + # Before a col sweep: + # col_update_prep_gramian_op + # initialize_col_update_op + # During a col sweep: + # update_col_factors_op + + is_row_sweep_var = variable_scope.variable( + True, + trainable=False, + name="is_row_sweep", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + is_sweep_done_var = variable_scope.variable( + False, + trainable=False, + name="is_sweep_done", + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + completed_sweeps_var = variable_scope.variable( + 0, + trainable=False, + name=WALSMatrixFactorization.COMPLETED_SWEEPS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + loss_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.LOSS, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + # The root weighted squared error = + # \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij ) + rwse_var = variable_scope.variable( + 0., + trainable=False, + name=WALSMatrixFactorization.RWSE, + collections=[ops.GraphKeys.GLOBAL_VARIABLES]) + + summary.scalar("loss", loss_var) + summary.scalar("root_weighted_squared_error", rwse_var) + summary.scalar("completed_sweeps", completed_sweeps_var) + + def create_axis_ops(sp_input, num_items, update_fn, axis_name): + """Creates book-keeping and training ops for a given axis. + + Args: + sp_input: A SparseTensor corresponding to the row or column batch. + num_items: An integer, the total number of items of this axis. + update_fn: A function that takes one argument (`sp_input`), and that + returns a tuple of + * new_factors: A flot Tensor of the factor values after update. + * update_op: a TensorFlow op which updates the factors. + * loss: A float Tensor, the unregularized loss. + * reg_loss: A float Tensor, the regularization loss. + * sum_weights: A float Tensor, the sum of factor weights. + axis_name: A string that specifies the name of the axis. + + Returns: + A tuple consisting of: + * reset_processed_items_op: A TensorFlow op, to be run before the + beginning of any sweep. It marks all items as not-processed. + * axis_train_op: A Tensorflow op, to be run during this axis' sweeps. + """ + processed_items_init = array_ops.fill(dims=[num_items], value=False) + with ops.colocate_with(processed_items_init): + processed_items = variable_scope.variable( + processed_items_init, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="processed_" + axis_name) + _, update_op, loss, reg, sum_weights = update_fn(sp_input) + input_indices = sp_input.indices[:, 0] + with ops.control_dependencies([ + update_op, + state_ops.assign(loss_var, loss + reg), + state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]): + with ops.colocate_with(processed_items): + update_processed_items = state_ops.scatter_update( + processed_items, + input_indices, + array_ops.ones_like(input_indices, dtype=dtypes.bool), + name="update_processed_{}_indices".format(axis_name)) + with ops.control_dependencies([update_processed_items]): + is_sweep_done = math_ops.reduce_all(processed_items) + axis_train_op = control_flow_ops.group( + state_ops.assign(is_sweep_done_var, is_sweep_done), + state_ops.assign_add( + completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32)), + name="{}_sweep_train_op".format(axis_name)) + return processed_items.initializer, axis_train_op + + reset_processed_rows_op, row_train_op = create_axis_ops( + input_rows, + params["num_rows"], + lambda x: model.update_row_factors(sp_input=x, transpose_input=False), + "rows") + reset_processed_cols_op, col_train_op = create_axis_ops( + input_cols, + params["num_cols"], + lambda x: model.update_col_factors(sp_input=x, transpose_input=True), + "cols") + switch_op = control_flow_ops.group( + state_ops.assign( + is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)), + reset_processed_rows_op, + reset_processed_cols_op, + name="sweep_switch_op") + row_prep_ops = [ + model.row_update_prep_gramian_op, model.initialize_row_update_op] + col_prep_ops = [ + model.col_update_prep_gramian_op, model.initialize_col_update_op] + init_op = model.worker_init + sweep_hook = _SweepHook( + is_row_sweep_var, is_sweep_done_var, init_op, + row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op) + global_step_hook = _IncrementGlobalStepHook() + training_hooks = [sweep_hook, global_step_hook] + if max_sweeps is not None: + training_hooks.append(_StopAtSweepHook(max_sweeps)) + + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.TRAIN, + predictions={}, + loss=loss_var, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=training_hooks) + + # INFER mode + elif mode == model_fn.ModeKeys.INFER: projection_weights = features.get( WALSMatrixFactorization.PROJECTION_WEIGHTS) @@ -364,17 +331,45 @@ def _wals_factorization_model_function(features, labels, mode, params): projection_weights=projection_weights, transpose_input=True) - predictions[WALSMatrixFactorization.PROJECTION_RESULT] = ( - control_flow_ops.cond(project_row, get_row_projection, - get_col_projection)) + predictions = { + WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_projection, + get_col_projection) + } - return model_fn.ModelFnOps( - mode=mode, - predictions=predictions, - loss=loss, - eval_metric_ops={}, - train_op=train_op, - training_hooks=training_hooks) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.INFER, + predictions=predictions, + loss=None, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + # EVAL mode + elif mode == model_fn.ModeKeys.EVAL: + def get_row_loss(): + _, _, loss, reg, _ = model.update_row_factors( + sp_input=input_rows, transpose_input=False) + return loss + reg + def get_col_loss(): + _, _, loss, reg, _ = model.update_col_factors( + sp_input=input_cols, transpose_input=True) + return loss + reg + loss = control_flow_ops.cond( + features[WALSMatrixFactorization.PROJECT_ROW], + get_row_loss, + get_col_loss) + return model_fn.ModelFnOps( + mode=model_fn.ModeKeys.EVAL, + predictions={}, + loss=loss, + eval_metric_ops={}, + train_op=control_flow_ops.no_op(), + training_hooks=[]) + + else: + raise ValueError("mode=%s is not recognized." % str(mode)) class WALSMatrixFactorization(estimator.Estimator): @@ -452,6 +447,10 @@ class WALSMatrixFactorization(estimator.Estimator): PROJECTION_RESULT = "projection" # Name of the completed_sweeps variable COMPLETED_SWEEPS = "completed_sweeps" + # Name of the loss variable + LOSS = "WALS_loss" + # Name of the Root Weighted Squared Error variable + RWSE = "WALS_RWSE" def __init__(self, num_rows, diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 8bd72b7025aad80e387171b93b9b264da3ed0f66..36b483c6d7a59bba78b7fa22aac0714e278f22cc 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -417,73 +417,67 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase): class SweepHookTest(test.TestCase): - def setUp(self): - self._num_rows = 5 - self._num_cols = 7 - self._train_op = control_flow_ops.no_op() - self._row_prep_done = variables.Variable(False) - self._col_prep_done = variables.Variable(False) - self._init_done = variables.Variable(False) - self._row_prep_ops = [state_ops.assign(self._row_prep_done, True)] - self._col_prep_ops = [state_ops.assign(self._col_prep_done, True)] - self._init_ops = [state_ops.assign(self._init_done, True)] - self._input_row_indices_ph = array_ops.placeholder(dtypes.int64) - self._input_col_indices_ph = array_ops.placeholder(dtypes.int64) - def test_sweeps(self): - def ind_feed(row_indices, col_indices): - return { - self._input_row_indices_ph: row_indices, - self._input_col_indices_ph: col_indices - } + is_row_sweep_var = variables.Variable(True) + is_sweep_done_var = variables.Variable(False) + init_done = variables.Variable(False) + row_prep_done = variables.Variable(False) + col_prep_done = variables.Variable(False) + row_train_done = variables.Variable(False) + col_train_done = variables.Variable(False) + + init_op = state_ops.assign(init_done, True) + row_prep_op = state_ops.assign(row_prep_done, True) + col_prep_op = state_ops.assign(col_prep_done, True) + row_train_op = state_ops.assign(row_train_done, True) + col_train_op = state_ops.assign(col_train_done, True) + train_op = control_flow_ops.no_op() + switch_op = control_flow_ops.group( + state_ops.assign(is_sweep_done_var, False), + state_ops.assign(is_row_sweep_var, + math_ops.logical_not(is_row_sweep_var))) + mark_sweep_done = state_ops.assign(is_sweep_done_var, True) with self.test_session() as sess: - is_row_sweep_var = variables.Variable(True) - completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - [self._train_op], - self._num_rows, - self._num_cols, - self._input_row_indices_ph, - self._input_col_indices_ph, - self._row_prep_ops, - self._col_prep_ops, - self._init_ops, - completed_sweeps_var) + is_sweep_done_var, + init_op, + [row_prep_op], + [col_prep_op], + row_train_op, + col_train_op, + switch_op) mon_sess = monitored_session._HookedSession(sess, [sweep_hook]) sess.run([variables.global_variables_initializer()]) - # Init ops should run before the first run. Row sweep not completed. - mon_sess.run(self._train_op, ind_feed([0, 1, 2], [])) - self.assertTrue(sess.run(self._init_done), - msg='init ops not run by the sweep_hook') - self.assertTrue(sess.run(self._row_prep_done), - msg='row_prep not run by the sweep_hook') - self.assertTrue(sess.run(is_row_sweep_var), - msg='Row sweep is not complete but is_row_sweep is ' - 'False.') - # Row sweep completed. - mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertTrue(sess.run(completed_sweeps_var) == 1, - msg='Completed sweeps should be equal to 1.') - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - # Col init ops should run. Col sweep not completed. - mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) - self.assertTrue(sess.run(self._col_prep_done), - msg='col_prep not run by the sweep_hook') - self.assertFalse(sess.run(is_row_sweep_var), - msg='Col sweep is not complete but is_row_sweep is ' - 'True.') - self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is not complete but is_sweep_done is True.') - # Col sweep completed. - mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), - msg='Sweep is complete but is_sweep_done is False.') - self.assertTrue(sess.run(completed_sweeps_var) == 2, - msg='Completed sweeps should be equal to 2.') + # Row sweep. + mon_sess.run(train_op) + self.assertTrue(sess.run(init_done), + msg='init op not run by the Sweephook') + self.assertTrue(sess.run(row_prep_done), + msg='row_prep_op not run by the SweepHook') + self.assertTrue(sess.run(row_train_done), + msg='row_train_op not run by the SweepHook') + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Row sweep is not complete but is_row_sweep_var is False.') + # Col sweep. + mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue(sess.run(col_prep_done), + msg='col_prep_op not run by the SweepHook') + self.assertTrue(sess.run(col_train_done), + msg='col_train_op not run by the SweepHook') + self.assertFalse( + sess.run(is_row_sweep_var), + msg='Col sweep is not complete but is_row_sweep_var is True.') + # Row sweep. + mon_sess.run(mark_sweep_done) + mon_sess.run(train_op) + self.assertTrue( + sess.run(is_row_sweep_var), + msg='Col sweep is complete but is_row_sweep_var is False.') class StopAtSweepHookTest(test.TestCase): diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index 7a5a4cb8c9499b950a3ad89be710e48474d5791e..eccce99071dc1477cf4f3bb152f3304b3b0fc35a 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -47,10 +47,25 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "decode_video_op_cc", + srcs = ["decode_video_op.cc"], + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/contrib/ffmpeg/default:ffmpeg_lib", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], + alwayslink = 1, +) + tf_custom_op_library( name = "ffmpeg.so", deps = [ ":decode_audio_op_cc", + ":decode_video_op_cc", ":encode_audio_op_cc", ], ) @@ -59,6 +74,7 @@ cc_library( name = "ffmpeg_op_lib", deps = [ ":decode_audio_op_cc", + ":decode_video_op_cc", ":encode_audio_op_cc", ], ) @@ -81,6 +97,15 @@ tf_gen_op_wrapper_py( ], ) +tf_gen_op_wrapper_py( + name = "decode_video_op_py", + require_shape_functions = True, + visibility = ["//visibility:private"], + deps = [ + ":decode_video_op_cc", + ], +) + tf_py_test( name = "decode_audio_op_test", srcs = ["decode_audio_op_test.py"], @@ -115,6 +140,27 @@ tf_py_test( tags = ["manual"], ) +tf_py_test( + name = "decode_video_op_test", + size = "small", + srcs = ["decode_video_op_test.py"], + additional_deps = [ + ":ffmpeg_ops_py", + "@six_archive//:six", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + "//tensorflow/python:image_ops", + ], + data = [ + ":test_data", + ], + tags = [ + "manual", + "notap", + ], +) + py_library( name = "ffmpeg_ops_py", srcs = [ @@ -126,6 +172,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":decode_audio_op_py", + ":decode_video_op_py", ":encode_audio_op_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_for_generated_wrappers", diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index 2bcb7284e10991b19ee5607147371e8d505c7732..daba965a98893b992abdc598ec713f13020d6e91 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -26,9 +26,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['decode_audio', 'encode_audio'] +_allowed_symbols = ['decode_audio', 'encode_audio', 'decode_video'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index 4b1c8a337e10c7025ca06e2ed6e1b934716dc1d0..92fad70b1f9cc55e0690a3fbb35abcf56aa68f16 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -37,29 +37,6 @@ namespace { // https://www.ffmpeg.org/ffmpeg-formats.html const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"}; -// Writes binary data to a file. -Status WriteFile(const string& filename, tensorflow::StringPiece contents) { - Env& env = *Env::Default(); - std::unique_ptr file; - TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file)); - TF_RETURN_IF_ERROR(file->Append(contents)); - TF_RETURN_IF_ERROR(file->Close()); - return Status::OK(); -} - -// Cleans up a file on destruction. -class FileDeleter { - public: - explicit FileDeleter(const string& filename) : filename_(filename) {} - ~FileDeleter() { - Env& env = *Env::Default(); - env.DeleteFile(filename_).IgnoreError(); - } - - private: - const string filename_; -}; - /* * Decoding implementation, shared across V1 and V2 ops. Creates a new * output in the context. @@ -69,7 +46,7 @@ void Decode(OpKernelContext* context, const string& file_format, const int32 samples_per_second, const int32 channel_count) { // Write the input data to a temp file. - const string temp_filename = GetTempFilename(file_format); + const string temp_filename = io::GetTempFilename(file_format); OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents)); FileDeleter deleter(temp_filename); diff --git a/tensorflow/contrib/ffmpeg/decode_video_op.cc b/tensorflow/contrib/ffmpeg/decode_video_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d44032968d559bec14722902a4d47d22c46ea4aa --- /dev/null +++ b/tensorflow/contrib/ffmpeg/decode_video_op.cc @@ -0,0 +1,118 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include + +#include +#include + +#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace ffmpeg { + +class DecodeVideoOp : public OpKernel { + public: + explicit DecodeVideoOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + OP_REQUIRES( + context, context->num_inputs() == 1, + errors::InvalidArgument("DecodeVideo requires exactly 1 input.")); + const Tensor& contents_tensor = context->input(0); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_tensor.shape()), + errors::InvalidArgument( + "contents must be a rank-0 tensor but got shape ", + contents_tensor.shape().DebugString())); + const tensorflow::StringPiece contents = contents_tensor.scalar()(); + + // Write the input data to a temp file. + string extension; + const string temp_filename = io::GetTempFilename(extension); + OP_REQUIRES_OK(context, WriteFile(temp_filename, contents)); + FileDeleter deleter(temp_filename); + + uint32 width = 0; + uint32 height = 0; + uint32 frames = 0; + + // Run FFmpeg on the data and verify results. + std::vector output_data; + const Status result = ffmpeg::ReadVideoFile(temp_filename, &output_data, + &width, &height, &frames); + if (result.code() == error::Code::NOT_FOUND) { + OP_REQUIRES( + context, result.ok(), + errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg " + "can be found at http://www.ffmpeg.org.")); + } else if (result.code() == error::UNKNOWN) { + LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message() + << "'. Returning empty tensor."; + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({0, 0}), &output)); + return; + } else { + OP_REQUIRES_OK(context, result); + } + OP_REQUIRES(context, !output_data.empty(), + errors::Unknown("No output created by FFmpeg.")); + OP_REQUIRES( + context, output_data.size() == (frames * height * width * 3), + errors::Unknown("Output created by FFmpeg [", output_data.size(), + "] does not match description [", frames, ", ", height, + ", ", width, ", 3]")); + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + 0, TensorShape({frames, height, width, 3}), &output)); + auto output_flat = output->flat(); + std::copy_n(output_data.begin(), output_data.size(), &output_flat(0)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("DecodeVideo").Device(DEVICE_CPU), DecodeVideoOp); + +REGISTER_OP("DecodeVideo") + .Input("contents: string") + .Output("output: uint8") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShapeOfRank(4)); + return Status::OK(); + }) + .Doc(R"doc( +Processes the contents of an audio file into a tensor using FFmpeg to decode +the file. + +One row of the tensor is created for each channel in the audio file. Each +channel contains audio samples starting at the beginning of the audio and +having `1/samples_per_second` time between them. If the `channel_count` is +different from the contents of the file, channels will be merged or created. + +contents: The binary audio file contents, as a string or rank-0 string + tensor. +)doc"); + +} // namespace ffmpeg +} // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b43b6b8919223bd7731209d5423b142601396ea5 --- /dev/null +++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py @@ -0,0 +1,69 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Tests for third_party.tensorflow.contrib.ffmpeg.decode_video_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + +import six # pylint: disable=unused-import + +from tensorflow.contrib import ffmpeg +from tensorflow.python.ops import image_ops +from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import test + + +class DecodeVideoOpTest(test.TestCase): + + def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, + index): + """Loads an video file and validates the output tensor. + + Args: + filename: The filename of the input file. + width: The width of the video. + height: The height of the video. + frames: The frames of the video. + bmp_filename: The filename for the bmp file. + index: Index location inside the video. + """ + with self.test_session(): + path = os.path.join(resource_loader.get_data_files_path(), 'testdata', + filename) + with open(path, 'rb') as f: + contents = f.read() + + bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata', + bmp_filename) + with open(bmp_path, 'rb') as f: + bmp_contents = f.read() + + image_op = image_ops.decode_bmp(bmp_contents) + image = image_op.eval() + self.assertEqual(image.shape, (height, width, 3)) + video_op = ffmpeg.decode_video(contents) + video = video_op.eval() + self.assertEqual(video.shape, (frames, height, width, 3)) + self.assertAllEqual(video[index, :, :, :], image) + + def testMp4(self): + self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 545a4386d043af604a747b8b5a8103101812b177..1245f515fe84f02e8470dbf941243bcd9834f3d0 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -16,6 +16,7 @@ #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" #include +#include #include #include #include @@ -25,6 +26,7 @@ #include #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" @@ -38,28 +40,45 @@ namespace { const char kFfmpegExecutable[] = "ffmpeg"; const int32 kDefaultProbeSize = 5000000; // 5MB -std::vector FfmpegCommandLine(const string& input_filename, - const string& output_filename, - const string& input_format_id, - int32 samples_per_second, - int32 channel_count) { - return { - "-nostats", // No additional progress display. - "-nostdin", // No interactive commands accepted. - "-f", input_format_id, // eg: "mp3" - "-probesize", StrCat(kDefaultProbeSize), - "-i", input_filename, - "-loglevel", "info", // Enable verbose logging to support debugging. - "-map_metadata", "-1", // Copy global metadata from input to output. - "-vn", // No video recording. - "-ac:a:0", StrCat(channel_count), - "-ar:a:0", StrCat(samples_per_second), - // Output set (in several ways) to signed 16-bit little-endian ints. - "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", - "-sn", // No subtitle recording. - "-y", // Overwrite output file. - StrCat(output_filename) - }; +std::vector FfmpegAudioCommandLine(const string& input_filename, + const string& output_filename, + const string& input_format_id, + int32 samples_per_second, + int32 channel_count) { + return {"-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-f", input_format_id, // eg: "mp3" + "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, + "-loglevel", "info", // Enable verbose logging to support debugging. + "-map_metadata", "-1", // Copy global metadata from input to output. + "-vn", // No video recording. + "-ac:a:0", StrCat(channel_count), "-ar:a:0", + StrCat(samples_per_second), + // Output set (in several ways) to signed 16-bit little-endian ints. + "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", + "-sn", // No subtitle recording. + "-y", // Overwrite output file. + StrCat(output_filename)}; +} + +std::vector FfmpegVideoCommandLine(const string& input_filename, + const string& output_filename) { + return {"-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-i", + input_filename, + "-f", + "image2pipe", + "-probesize", + StrCat(kDefaultProbeSize), + "-loglevel", + "info", // Enable verbose logging to support debugging. + "-vcodec", + "rawvideo", + "-pix_fmt", + "rgb24", + "-y", // Overwrite output file. + StrCat(output_filename)}; } // Is a named binary installed and executable by the current process? @@ -106,7 +125,7 @@ bool IsBinaryInstalled(const string& binary_name) { ::execvp(kFfmpegExecutable, args_chars.data()); // exec only returns on error. const int error = errno; - LOG(ERROR) << "FFmpeg could not be executed: " << error; + LOG(ERROR) << "FFmpeg could not be executed: " << strerror(error); ::_exit(error); } @@ -198,52 +217,101 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, return data; } -// Returns a unique number every time it is called. -int64 UniqueId() { - static mutex mu(LINKER_INITIALIZED); - static int64 id = 0; - mutex_lock l(mu); - return ++id; -} - -} // namespace - -string GetTempFilename(const string& extension) { - for (const char* dir : std::vector( - {getenv("TEST_TMPDIR"), getenv("TMPDIR"), getenv("TMP"), "/tmp"})) { - if (!dir || !dir[0]) { +Status ReadInfoFile(const string& filename, uint32* width, uint32* height, + uint32* frames) { + string data; + TF_QCHECK_OK(ReadFileToString(Env::Default(), filename, &data)) + << "Could not read FFmpeg file: " << filename; + bool in_output = false; + bool in_mapping = false; + uint32 frames_value = 0; + uint32 height_value = 0; + uint32 width_value = 0; + for (const string& line : str_util::Split(data, '\n')) { + // Output starts with the first line of `Output #..`. + // Further processing output region starts next line so we could continue + // the loop. + if (!in_output && line.find("Output #") == 0) { + in_output = true; + in_mapping = false; continue; } - struct stat statbuf; - if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) { - // UniqueId is added here because mkstemps is not as thread safe as it - // looks. https://github.com/tensorflow/tensorflow/issues/5804 shows - // the problem. - string tmp_filepath = io::JoinPath( - dir, - StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX.", extension)); - int fd = mkstemps(&tmp_filepath[0], extension.length() + 1); - if (fd < 0) { - LOG(FATAL) << "Failed to create temp file."; - } else { - close(fd); - return tmp_filepath; + // Stream mapping starts with the first line of `Stream mapping`, it also + // signals the end of Output section. + // Further processing of stream mapping region starts next line so we could + // continue the loop. + if (!in_mapping && line.find("Stream mapping:") == 0) { + in_output = false; + in_mapping = true; + continue; + } + if (in_output) { + // We only look for the first stream in output `Stream #0`. + // Once processed we will not further process output section. + if (line.find(" Stream #") == 0) { + size_t p = line.find(", rgb24, ", 24); + if (p != std::string::npos) { + string rgb24 = line.substr(p + 9, line.find(" ", p + 9)); + rgb24 = rgb24.substr(0, rgb24.find(",")); + string rgb24_width = rgb24.substr(0, rgb24.find("x")); + string rgb24_height = rgb24.substr(rgb24_width.length() + 1); + if (strings::safe_strtou32(rgb24_width, &width_value) && + strings::safe_strtou32(rgb24_height, &height_value)) { + in_output = false; + } + } + } + continue; + } + if (in_mapping) { + // We only look for the first stream mapping to have the number of the + // frames. + // Once processed we will not further process stream mapping section. + if (line.find("frame= ") == 0) { + string number = line.substr(8, line.find(" ", 8)); + number = number.substr(0, number.find(" ")); + if (strings::safe_strtou32(number, &frames_value)) { + in_mapping = false; + } } + continue; } } - LOG(FATAL) << "No temp directory found."; + if (frames_value == 0 || height_value == 0 || width_value == 0) { + return errors::Unknown("Not enough video info returned by FFmpeg [", + frames_value, ", ", height_value, ", ", width_value, + ", 3]"); + } + *width = width_value; + *height = height_value; + *frames = frames_value; + return Status::OK(); } -Status ReadAudioFile(const string& filename, - const string& audio_format_id, - int32 samples_per_second, - int32 channel_count, +} // namespace + +FileDeleter::~FileDeleter() { + Env& env = *Env::Default(); + env.DeleteFile(filename_).IgnoreError(); +} + +Status WriteFile(const string& filename, StringPiece contents) { + Env& env = *Env::Default(); + std::unique_ptr file; + TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file)); + TF_RETURN_IF_ERROR(file->Append(contents)); + TF_RETURN_IF_ERROR(file->Close()); + return Status::OK(); +} + +Status ReadAudioFile(const string& filename, const string& audio_format_id, + int32 samples_per_second, int32 channel_count, std::vector* output_samples) { // Create an argument list. - string output_filename = GetTempFilename("raw"); + string output_filename = io::GetTempFilename("raw"); const std::vector args = - FfmpegCommandLine(filename, output_filename, audio_format_id, - samples_per_second, channel_count); + FfmpegAudioCommandLine(filename, output_filename, audio_format_id, + samples_per_second, channel_count); // Unfortunately, it's impossible to differentiate an exec failure due to the // binary being missing and an error from the binary's execution. Therefore, @@ -256,7 +324,8 @@ Status ReadAudioFile(const string& filename, // Execute ffmpeg and report errors. pid_t child_pid = ::fork(); if (child_pid < 0) { - return Status(error::Code::UNKNOWN, StrCat("fork failed: ", errno)); + return Status(error::Code::UNKNOWN, + StrCat("fork failed: ", strerror(errno))); } if (child_pid == 0) { ExecuteFfmpeg(args); @@ -285,5 +354,63 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, return Status::OK(); } +Status ReadVideoFile(const string& filename, std::vector* output_data, + uint32* width, uint32* height, uint32* frames) { + if (!IsBinaryInstalled(kFfmpegExecutable)) { + return Status(error::Code::NOT_FOUND, StrCat("FFmpeg could not be found.")); + } + + string output_filename = io::GetTempFilename("raw"); + string stderr_filename = io::GetTempFilename("err"); + + // Create an argument list. + const std::vector args = + FfmpegVideoCommandLine(filename, output_filename); + + // Execute ffmpeg and report errors. + pid_t child_pid = ::fork(); + if (child_pid < 0) { + return Status(error::Code::UNKNOWN, + StrCat("fork failed: ", strerror(errno))); + } + if (child_pid == 0) { + const int fd = + open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600); + if (fd < 0) { + const int error = errno; + LOG(ERROR) << "FFmpeg stderr file could not be created: " + << strerror(error); + ::_exit(error); + } + close(STDERR_FILENO); + dup2(fd, STDERR_FILENO); + ExecuteFfmpeg(args); + } else { + int status_code; + if (::waitpid(child_pid, &status_code, 0) < 0) { + return Status(error::Code::UNKNOWN, + StrCat("waitpid failed: ", strerror(errno))); + } + if (status_code) { + return Status(error::Code::UNKNOWN, + StrCat("FFmpeg execution failed: ", status_code)); + } + + TF_QCHECK_OK(ReadInfoFile(stderr_filename, width, height, frames)) + << "Could not read FFmpeg stderr file: " << stderr_filename; + + string raw_data; + TF_QCHECK_OK(ReadFileToString(Env::Default(), output_filename, &raw_data)) + << "Could not read FFmpeg output file: " << output_filename; + output_data->resize(raw_data.size()); + std::copy_n(raw_data.data(), raw_data.size(), output_data->begin()); + + TF_QCHECK_OK(Env::Default()->DeleteFile(output_filename)) + << output_filename; + TF_QCHECK_OK(Env::Default()->DeleteFile(stderr_filename)) + << stderr_filename; + return Status::OK(); + } +} } // namespace ffmpeg } // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index 2871c1462894c6a4ddef63e9178272df0d14824c..85b61b26163d87a10d4e316720b4f633e038bbec 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -39,7 +39,7 @@ const char kTestMp3Filename[] = // Set to true via a command line flag iff the test is expected to have FFmpeg // installed. -mutex mu; +mutex mu(LINKER_INITIALIZED); bool should_ffmpeg_be_installed GUARDED_BY(mu) = false; string ParseTestFlags(int* argc, char** argv) { diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 7176f3b550679555d5ab3b70f2b360a90eaee253..36fc71794b06e0f3cb86c40b325ce50e8999c667 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -20,7 +20,10 @@ #include #include + +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" @@ -49,7 +52,7 @@ TEST(FfmpegLibTest, TestTempDirectoryThreading) { pool.Schedule([&mu, &temp_filenames, environment]() { std::array buffer; for (int32 j = 0; j < kStringsPerItem; ++j) { - buffer[j] = GetTempFilename("mp3"); + buffer[j] = io::GetTempFilename("mp3"); TF_QCHECK_OK(environment->DeleteFile(buffer[j])); } mutex_lock l(mu); diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h index f64007c81d74276d42c9d6ebd7c8f46cda6b7d72..c5ea1432bf8b61c87615074a93a45325371c4c87 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h +++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h @@ -24,16 +24,24 @@ namespace tensorflow { namespace ffmpeg { -// Gets a temp filename in an appropriate location. -string GetTempFilename(const string& extension); +// Cleans up a file on destruction. +class FileDeleter { + public: + explicit FileDeleter(const string& filename) : filename_(filename) {} + ~FileDeleter(); + + private: + const string filename_; +}; + +// Writes binary data to a file. +Status WriteFile(const string& filename, tensorflow::StringPiece contents); // Reads an audio file using ffmpeg and converts it into an array of samples in // [-1.0, 1.0]. If there are multiple channels in the audio then each frame will // contain a separate sample for each channel. Frames are ordered by time. -Status ReadAudioFile(const string& filename, - const string& audio_format_id, - int32 samples_per_second, - int32 channel_count, +Status ReadAudioFile(const string& filename, const string& audio_format_id, + int32 samples_per_second, int32 channel_count, std::vector* output_samples); // Creates an audio file using ffmpeg in a specific format. The samples are in @@ -45,6 +53,11 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second, int32 samples_per_second, int32 channel_count, const std::vector& samples, string* output_data); +// Reads an video file using ffmpeg adn converts it into a RGB24 in uint8 +// [frames, height, width, 3]. The w, h, and frames are obtained from ffmpeg. +Status ReadVideoFile(const string& filename, std::vector* output_data, + uint32* width, uint32* height, uint32* frames); + } // namespace ffmpeg } // namespace tensorflow diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 18b0b8b812c908cff62a241aa59b3a53021123f4..08b5a6ea48c2d4959af68a2ee9d27d21c6245457 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -19,7 +19,9 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader @@ -89,3 +91,19 @@ def encode_audio(audio, file_format=None, samples_per_second=None): ops.NotDifferentiable('EncodeAudio') + + +def decode_video(contents): + """Create an op that decodes the contents of a video file. + + Args: + contents: The binary contents of the video file to decode. This is a + scalar. + + Returns: + A rank-4 `Tensor` that has `[frames, height, width, 3]` RGB as output. + """ + return gen_decode_video_op_py.decode_video(contents) + + +ops.NotDifferentiable('DecodeVideo') diff --git a/tensorflow/contrib/ffmpeg/testdata/small.mp4 b/tensorflow/contrib/ffmpeg/testdata/small.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1fc478842f51e7519866f474a02ad605235bc6a6 Binary files /dev/null and b/tensorflow/contrib/ffmpeg/testdata/small.mp4 differ diff --git a/tensorflow/contrib/ffmpeg/testdata/small_100.bmp b/tensorflow/contrib/ffmpeg/testdata/small_100.bmp new file mode 100644 index 0000000000000000000000000000000000000000..61f53a2a21c933037f004d6ae4319dc6065fb886 Binary files /dev/null and b/tensorflow/contrib/ffmpeg/testdata/small_100.bmp differ diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 891425fd8cae6fbbf60d30cbd9137c049073456c..5b659ddaa1386736eb8cc05a203ed1827ccd160e 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -24,6 +24,7 @@ tf_custom_op_py_library( "python/framework/__init__.py", "python/framework/checkpoint_utils.py", "python/framework/experimental.py", + "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", "python/ops/accumulate_n_v2.py", @@ -32,6 +33,7 @@ tf_custom_op_py_library( "python/ops/checkpoint_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", + "python/ops/sort_ops.py", "python/ops/variables.py", ], dso = [ @@ -231,6 +233,17 @@ py_test( ], ) +py_test( + name = "graph_util_test", + srcs = ["python/framework/graph_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + ], +) + py_test( name = "tensor_util_test", srcs = ["python/framework/tensor_util_test.py"], @@ -263,6 +276,7 @@ py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:partitioned_variables", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -307,6 +321,20 @@ py_test( ], ) +py_test( + name = "sort_ops_test", + size = "medium", + srcs = ["python/ops/sort_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 8421ba7c0423c6ed274f92ba74930822d0171e05..4edc77f86ba786ca547b8d3842e2cf02833fbbac 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -65,6 +65,7 @@ See the @{$python/contrib.framework} guide. @@get_variable_full_name @@get_variables_to_restore @@get_variables +@@global_variable @@local_variable @@model_variable @@variable @@ -79,6 +80,8 @@ See the @{$python/contrib.framework} guide. @@load_embedding_initializer @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer + +@@sort """ from __future__ import absolute_import diff --git a/tensorflow/contrib/framework/python/framework/__init__.py b/tensorflow/contrib/framework/python/framework/__init__.py index c8e6a4685498a4d89cef44f6a9a3acbe7557cb67..2d49771ab756359712a3ee0b23649c231678f952 100644 --- a/tensorflow/contrib/framework/python/framework/__init__.py +++ b/tensorflow/contrib/framework/python/framework/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.framework.checkpoint_utils import * from tensorflow.contrib.framework.python.framework.experimental import experimental +from tensorflow.contrib.framework.python.framework.graph_util import * from tensorflow.contrib.framework.python.framework.tensor_util import * # pylint: enable=wildcard-import from tensorflow.python.util import decorator_utils diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a18ff2320d99726bb355ff6179fc97a070c2fec7 --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -0,0 +1,154 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers to manipulate a tensor graph in python. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import copy +import six + +# pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present +from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes +from tensorflow.python.framework.graph_util_impl import _extract_graph_summary +from tensorflow.python.framework.graph_util_impl import _node_name + + +__all__ = ["fuse_op", "get_placeholders"] + + +def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, + output_quantized, op_name, op_type): + """Fuse subgraph between input_nodes and output_nodes into a single custom op. + + Args: + graph_def: A graph_pb2.GraphDef proto. + input_nodes: input nodes to the subgraph to be fused. + output_nodes: output nodes to the subgraph to be fused. + output_dtypes: A list of output datatypes for the custom op + output_quantized: A boolean flag that indicates if output is quantized + op_name: fused op name. + op_type: fused op type. + Returns: + The GraphDef of the new graph. + + Raises: + TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto. + """ + + if not isinstance(graph_def, graph_pb2.GraphDef): + raise TypeError("graph_def must be a graph_pb2.GraphDef proto.") + + if isinstance(input_nodes, six.string_types): + raise TypeError("input_nodes must be a list.") + + if isinstance(output_nodes, six.string_types): + raise TypeError("output_nodes must be a list.") + + name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( + graph_def) + _assert_nodes_are_present(name_to_node, input_nodes + output_nodes) + + # Nodes upto and including input_nodes + reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name) + # Nodes upto and including output_nodes + reachable_by_output = _bfs_for_reachable_nodes(output_nodes, + name_to_input_name) + + # Set of nodes in the list input_nodes + input_nodes_set = set(input_nodes) + + # Set of nodes in the list output_nodes + output_nodes_set = set(output_nodes) + + nodes_post_output = [] + for node in graph_def.node: + n = _node_name(node.name) + if n in reachable_by_output: + if n not in reachable_by_input and n not in output_nodes_set: + # n is between input and output, i.e., part of the fused op + next_to_visit = [n] + while next_to_visit: + cur_node = next_to_visit[0] + del next_to_visit[0] + if cur_node in reachable_by_input and cur_node not in input_nodes_set: + raise TypeError("Node %s uses input %s not in input_nodes." % + (n, cur_node)) + if cur_node not in input_nodes_set: + next_to_visit += name_to_input_name[cur_node] + elif n not in reachable_by_input: + nodes_post_output.append(n) + + # Add all nodes upto the input nodes + out = graph_pb2.GraphDef() + reachable_by_input_sorted = sorted( + list(reachable_by_input), key=lambda n: name_to_seq_num[n]) + for node in reachable_by_input_sorted: + out.node.extend([copy.deepcopy(name_to_node[node])]) + + # Add the custom op + new_node = node_def_pb2.NodeDef() + for node in input_nodes: + new_node.input.append(node) + new_node.attr["_output_types"].list.type[:] = output_dtypes + new_node.attr["_output_quantized"].b = output_quantized + new_node.op = op_type + new_node.name = op_name + out.node.extend([new_node]) + + # Add the nodes in the output of the custom op + for index, n in enumerate(output_nodes): + assert len(name_to_node[n].input) == 1 + new_node = copy.deepcopy(name_to_node[n]) + del new_node.input[:] + new_node.input.append(op_name + (":" + str(index) if index != 0 else "")) + out.node.extend([new_node]) + + # Add the nodes post output_nodes + for n in nodes_post_output: + out.node.extend([copy.deepcopy(name_to_node[n])]) + + out.library.CopyFrom(graph_def.library) + out.versions.CopyFrom(graph_def.versions) + return out + + +def get_placeholders(graph): + """Get placeholders of a graph. + + Args: + graph: A tf.Graph. + Returns: + A list contains all placeholders of given graph. + + Raises: + TypeError: If `graph` is not a tensorflow graph. + """ + + if not isinstance(graph, ops.Graph): + raise TypeError("Input graph needs to be a Graph: %s" % graph) + + # For each placeholder() call, there is a corresponding + # operation of type 'Placeholder' registered to the graph. + # The return value (a Tensor) of placeholder() is the + # first output of this operation in fact. + operations = graph.get_operations() + result = [i.outputs[0] for i in operations if i.type == "Placeholder"] + return result diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a6d109e19211d271c2b15bac66ddacd38fe395 --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""@graph_util tests.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.framework import graph_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +def GetNewNode(name, op, input_nodes): + new_node = node_def_pb2.NodeDef() + new_node.op = op + new_node.name = name + for node in input_nodes: + new_node.input.append(node) + return new_node + + +class GraphUtilTest(test.TestCase): + + def testGraphUtil(self): + graph_def = graph_pb2.GraphDef() + node_a = GetNewNode('A', 'Placeholder', []) + node_b = GetNewNode('B', 'Op1', ['A']) + node_c = GetNewNode('C', 'Op1', ['B']) + node_d = GetNewNode('D', 'Op1', ['C']) + node_e = GetNewNode('E', 'Op1', ['D']) + graph_def.node.extend([node_a, node_b, node_c, node_d, node_e]) + fused_graph_def = graph_util.fuse_op( + graph_def, ['A'], ['D'], [types_pb2.DT_FLOAT], True, 'FusedOp', 'Op2') + self.assertEqual(len(fused_graph_def.node), 4) + self.assertEqual(fused_graph_def.node[0].name, 'A') + self.assertEqual(fused_graph_def.node[1].name, 'FusedOp') + self.assertEqual(fused_graph_def.node[1].input[0], 'A') + self.assertEqual(fused_graph_def.node[1].op, 'Op2') + self.assertEqual(fused_graph_def.node[1].attr['_output_quantized'].b, True) + self.assertEqual(fused_graph_def.node[1].attr['_output_types'].list.type, + [types_pb2.DT_FLOAT]) + self.assertEqual(fused_graph_def.node[2].name, 'D') + self.assertEqual(fused_graph_def.node[3].name, 'E') + + def testGraphUtilArtificialDependencyInjection(self): + graph_def = graph_pb2.GraphDef() + node_a = GetNewNode('A', 'Placeholder', []) + node_a1 = GetNewNode('A1', 'Placeholder', []) + node_b = GetNewNode('B', 'Op1', ['A']) + node_c = GetNewNode('C', 'Op1', ['B']) + node_d = GetNewNode('D', 'Op1', ['C']) + node_e = GetNewNode('E', 'Op1', ['D']) + graph_def.node.extend([node_a, node_a1, node_b, node_c, node_d, node_e]) + fused_graph_def = graph_util.fuse_op(graph_def, ['A', 'A1'], ['D'], + [types_pb2.DT_FLOAT], True, 'FusedOp', + 'Op2') + self.assertEqual(len(fused_graph_def.node), 5) + self.assertEqual(fused_graph_def.node[0].name, 'A') + self.assertEqual(fused_graph_def.node[1].name, 'A1') + self.assertEqual(fused_graph_def.node[2].name, 'FusedOp') + self.assertEqual(fused_graph_def.node[2].input[0], 'A') + self.assertEqual(fused_graph_def.node[2].op, 'Op2') + self.assertEqual(fused_graph_def.node[2].attr['_output_quantized'].b, True) + self.assertEqual(fused_graph_def.node[2].attr['_output_types'].list.type, + [types_pb2.DT_FLOAT]) + self.assertEqual(fused_graph_def.node[3].name, 'D') + self.assertEqual(fused_graph_def.node[4].name, 'E') + + +class GetPlaceholdersTest(test.TestCase): + + def test_get_placeholders(self): + with ops.Graph().as_default() as g: + placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)] + results = graph_util.get_placeholders(g) + self.assertEqual( + sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access + sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index edef37cf0c0719bf10a4c75c34adb30b9716cdcd..685bb94779762ce46ee342e7e0a182c54be64743 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -24,5 +24,6 @@ from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * +from tensorflow.contrib.framework.python.ops.sort_ops import * from tensorflow.contrib.framework.python.ops.variables import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py index a0667bd489213cf366e27114a91e8699ed9e7428..2375ee4f550616ff60d20b87b5773704d8fbbe1e 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py @@ -48,7 +48,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] # Explicitly pass shape and type - tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) # [[7, 4], # [6, 14]] ``` @@ -93,7 +93,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): elif len(inputs) == 1 and name is not None: return array_ops.identity(inputs[0], name=name) elif context.in_eager_mode(): - # TemporaryVariable not currently supported in eager mode; fall back + # TemporaryVariable not currently supported in eager mode; fall back # onto AddN for now. # TODO(frreiss) remove this once the lifetime of eager variables gets # addressed @@ -101,7 +101,7 @@ def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): else: return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) -# The following code should eventually be merged into +# The following code should eventually be merged into # tensorflow/python/ops/math_grad.py @ops.RegisterGradient("AccumulateNV2") def _AddNGrad(op, grad): diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py index c2229bb8ad3d5b38321d16f150ed94175ab9bdbe..8f44698da851b48abf831e957c80fa1643a58bda 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into +"""Tests for new version of accumulate_n op that will eventually go into `ops.math_ops`. -These test cases spefically exercise the `eager` APIs. They need to be in a +These test cases spefically exercise the `eager` APIs. They need to be in a separate file from the remaining tests because eager mode is currently something you can turn on but can't turn off for the lifetime of the current process.""" from __future__ import absolute_import @@ -64,7 +64,7 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): np.random.seed(42) num_inputs = 3 input_vars = [ - resource_variable_ops.ResourceVariable(10.0 * np.random.random(), + resource_variable_ops.ResourceVariable(10.0 * np.random.random(), name="t%d" % i) for i in range(0, num_inputs) ] @@ -72,7 +72,7 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): def fn(first, second, third): return av2.accumulate_n_v2([first, second, third]) - grad_fn = backprop.gradients_function(fn) + grad_fn = backprop.gradients_function(fn) grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 [elem.numpy() for elem in grad]) diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py index 3386e849d5cb8516ab3b1f6cb0429be3fc2fc960..b5e9f8df79262635bf579a6bf2260bc40c140c6f 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into +"""Tests for new version of accumulate_n op that will eventually go into `ops.math_ops`.""" from __future__ import absolute_import from __future__ import division @@ -102,21 +102,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(np.array([0.1,0.2])) b = variables.Variable(np.array([[0.3],[0.4]])) - tf_val = av2.accumulate_n_v2([a,b]) + tf_val = av2.accumulate_n_v2([a,b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) if __name__ == "__main__": diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8f62f0ea7b9b561f235b9496ffda97a9f378d530 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -0,0 +1,113 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Support for sorting tensors. + +@@sort +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops + + +def sort(values, axis=-1, direction='ASCENDING', name=None): + """Sorts a tensor. + + Args: + values: 1-D or higher numeric `Tensor`. + axis: The axis along which to sort. The default is -1, which sorts the last + axis. + direction: The direction in which to sort the values (`'ASCENDING'` or + `'DESCENDING'`). + name: Optional name for the operation. + + Returns: + A `Tensor` with the same dtype and shape as `values`, with the elements + sorted along the given `axis`. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + with framework_ops.name_scope(name, 'sort'): + if direction not in _SORT_IMPL: + raise ValueError('%s should be one of %s' % + (direction, ', '.join(sorted(_SORT_IMPL.keys())))) + # Axis must be an integer, not a Tensor. + axis = framework_ops.convert_to_tensor(axis, name='axis') + axis_static = tensor_util.constant_value(axis) + if axis.shape.ndims != 0 or axis_static is None: + raise ValueError('axis must be a constant scalar') + axis_static = int(axis_static) # Avoids NumPy casting error + + values = framework_ops.convert_to_tensor(values, name='values') + + return _SORT_IMPL[direction](values, axis_static) + + +def _descending_sort(values, axis): + """Sorts values in reverse using `top_k`. + + Args: + values: Tensor of numeric values. + axis: Index of the axis which values should be sorted along. + + Returns: + The sorted values. + """ + k = array_ops.shape(values)[axis] + rank = array_ops.rank(values) + # Fast path: sorting the last axis. + if axis == -1 or axis + 1 == values.get_shape().ndims: + return nn_ops.top_k(values, k)[0] + + # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. + if axis < 0: + # Make axis a Tensor with the real axis index if needed. + axis += rank + transposition = array_ops.concat( + [ + # Axes up to axis are unchanged. + math_ops.range(axis), + # Swap axis and rank - 1. + [rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + math_ops.range(axis + 1, rank - 1), + # Swap axis and rank - 1. + [axis] + ], + axis=0) + top_k_input = array_ops.transpose(values, transposition) + values, unused_indices = nn_ops.top_k(top_k_input, k) + # transposition contains a single cycle of length 2 (swapping 2 elements), + # so it is an involution (it is its own inverse). + return array_ops.transpose(values, transposition) + + +def _ascending_sort(values, axis): + # Negate the values to get the ascending order from descending sort. + values_or_indices = _descending_sort(-values, axis) + return -values_or_indices + + +_SORT_IMPL = { + 'ASCENDING': _ascending_sort, + 'DESCENDING': _descending_sort, +} diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d08ae502f10d98ee14d8bea2f76b18bedb935cea --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the sort wrapper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import sort_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class SortTest(test.TestCase): + + def testRandom_lowDimensionality(self): + self._testRandom_lowDimensionality(negative_axis=False) + + def testRandom_lowDimensionality_negative(self): + self._testRandom_lowDimensionality(negative_axis=True) + + def _testRandom_lowDimensionality(self, negative_axis): + np.random.seed(42) + for _ in range(20): + rank = np.random.randint(1, 3) + shape = [np.random.randint(0, 20) for _ in range(rank)] + arr = np.random.random(shape) + sort_axis = np.random.choice(rank) + if negative_axis: + sort_axis = -1 - sort_axis + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=sort_axis), + sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) + + def testRandom_highDimensionality(self): + np.random.seed(100) + for _ in range(20): + rank = np.random.randint(5, 15) + shape = [np.random.randint(1, 4) for _ in range(rank)] + arr = np.random.random(shape) + sort_axis = np.random.choice(rank) + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=sort_axis), + sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) + + def testScalar(self): + # Create an empty scalar where the static shape is unknown. + zeros_length_1 = array_ops.zeros( + random_ops.random_uniform([1], minval=0, maxval=1, dtype=dtypes.int32), + dtype=dtypes.int32) + scalar = array_ops.zeros(zeros_length_1) + + sort = sort_ops.sort(scalar) + with self.test_session(): + with self.assertRaises(errors.InvalidArgumentError): + sort.eval() + + def testNegativeOutOfBounds_staticShape(self): + arr = constant_op.constant([3, 4, 5]) + with self.assertRaises(ValueError): + sort_ops.sort(arr, axis=-4) + + def testDescending(self): + arr = np.random.random((10, 5, 5)) + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=0)[::-1], + sort_ops.sort( + constant_op.constant(arr), + axis=0, + direction='DESCENDING').eval()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 1bd9a14a7f3e17b30b811b3b73e5915c0dd1ec59..3f1ece4510578b5ac39849c577fffbb2a3be45a7 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -60,6 +60,7 @@ __all__ = ['add_model_variable', 'get_variable_full_name', 'get_variables_to_restore', 'get_variables', + 'global_variable', 'local_variable', 'model_variable', 'variable', @@ -147,20 +148,48 @@ def get_or_create_global_step(graph=None): return training_util.get_or_create_global_step(graph) -def local_variable(initial_value, validate_shape=True, name=None): - """Create variable and add it to `GraphKeys.LOCAL_VARIABLES` collection. +def local_variable(initial_value, + validate_shape=True, + name=None, + use_resource=None): + """Create a variable with a value and add it to `GraphKeys.LOCAL_VARIABLES`. Args: initial_value: See variables.Variable.__init__. validate_shape: See variables.Variable.__init__. name: See variables.Variable.__init__. + use_resource: If `True` use a ResourceVariable instead of a Variable. Returns: New variable. """ return variable_scope.variable( initial_value, trainable=False, collections=[ops.GraphKeys.LOCAL_VARIABLES], - validate_shape=validate_shape, name=name) + validate_shape=validate_shape, + use_resource=use_resource, + name=name) + + +def global_variable(initial_value, + validate_shape=True, + name=None, + use_resource=None): + """Create a variable with a value and add it to `GraphKeys.GLOBAL_VARIABLES`. + + Args: + initial_value: See variables.Variable.__init__. + validate_shape: See variables.Variable.__init__. + name: See variables.Variable.__init__. + use_resource: If `True` use a ResourceVariable instead of a Variable. + Returns: + New variable. + """ + return variable_scope.variable( + initial_value, trainable=False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + validate_shape=validate_shape, + use_resource=use_resource, + name=name) @contrib_add_arg_scope @@ -201,7 +230,7 @@ def variable(name, shape=None, dtype=None, initializer=None, else [ops.GraphKeys.GLOBAL_VARIABLES]) # Remove duplicates - collections = set(collections) + collections = list(set(collections)) getter = variable_scope.get_variable if custom_getter is not None: getter = functools.partial(custom_getter, @@ -412,7 +441,7 @@ def get_unique_variable(var_op_name): """ candidates = get_variables(scope=var_op_name) if not candidates: - raise ValueError('Couldnt find variable %s' % var_op_name) + raise ValueError('Couldn\'t find variable %s' % var_op_name) for candidate in candidates: if candidate.op.name == var_op_name: diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py index 6a74e4e8666e98ca3c97dc9ddd8a6c11613f708e..2f06df93acb0a4c0b36c68839ff531e3c22c5ee3 100644 --- a/tensorflow/contrib/framework/python/ops/variables_test.py +++ b/tensorflow/contrib/framework/python/ops/variables_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import gfile @@ -102,6 +103,82 @@ class LocalVariableTest(test.TestCase): sess.run(variables_lib.local_variables_initializer()) self.assertAllEqual(a.eval(), [0] * 5) + def testResourceVariable(self): + a = variables_lib2.local_variable(0) + b = variables_lib2.local_variable(0, use_resource=True) + self.assertEqual(type(a), variables_lib.Variable) + self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + + +class GlobalVariableTest(test.TestCase): + + def test_global_variable(self): + with self.test_session() as sess: + self.assertEquals([], variables_lib.global_variables()) + value0 = 42 + variables_lib2.global_variable(value0) + value1 = 43 + variables_lib2.global_variable(value1) + variables = variables_lib.global_variables() + self.assertEquals(2, len(variables)) + with self.assertRaisesOpError( + 'Attempting to use uninitialized value Variable'): + sess.run(variables) + variables_lib.variables_initializer(variables).run() + self.assertAllEqual(set([value0, value1]), set(sess.run(variables))) + + def testVariableNameAndShape(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable([1, 1, 1, 1, 1], name='a') + self.assertEquals(a.op.name, 'A/a') + self.assertListEqual(a.get_shape().as_list(), [5]) + self.assertListEqual([a], variables_lib.global_variables()) + + def testGlobalVariableNotInLocalVariables(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + self.assertFalse(a in variables_lib.local_variables()) + self.assertTrue(a in variables_lib.global_variables()) + + def testGlobalVariableInVariablesToRestore(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + self.assertFalse(a in variables_lib.local_variables()) + self.assertTrue(a in variables_lib2.get_variables_to_restore()) + + def testGetVariablesReturnsThem(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + a = variables_lib2.global_variable(0) + with variable_scope.variable_scope('B'): + b = variables_lib2.global_variable(0) + self.assertEquals([a], variables_lib2.get_variables('A')) + self.assertEquals([b], variables_lib2.get_variables('B')) + + def testGetLocalVariablesDontReturnsThem(self): + with self.test_session(): + with variable_scope.variable_scope('A'): + variables_lib2.global_variable(0) + with variable_scope.variable_scope('B'): + variables_lib2.global_variable(0) + self.assertEquals([], variables_lib2.get_local_variables('A')) + self.assertEquals([], variables_lib2.get_local_variables('B')) + + def testInitializedVariableValue(self): + with self.test_session() as sess: + a = variables_lib2.global_variable([0, 0, 0, 0, 0], name='a') + sess.run(variables_lib.global_variables_initializer()) + self.assertAllEqual(a.eval(), [0] * 5) + + def testResourceVariable(self): + a = variables_lib2.global_variable(0) + b = variables_lib2.global_variable(0, use_resource=True) + self.assertEqual(type(a), variables_lib.Variable) + self.assertEqual(type(b), resource_variable_ops.ResourceVariable) + class GlobalStepTest(test.TestCase): diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 88306094ab9947c9c78b03c0013f6afc88316803..5fec69ea4361a97c79ddc3188469e7ffb327f6cc 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -493,6 +493,8 @@ void LaunchFusedConv2DBiasActivationOp:: {{conv_input_rows, conv_input_cols}}, output_depth, {{filter_rows, filter_cols}}, + // TODO(yangzihao): Add support for arbitrary dilations for fused conv. + {{1, 1}}, // dilation_rows, dilation_cols {{row_stride, col_stride}}, {{padding_rows, padding_cols}}, conv_input->dtype(), diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index dc43af11580ce5fda74ee25da6c151a5b89c7aee..fa7a3c03aa35c756252b22a004be91fa24c10e41 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -30,11 +30,12 @@ class FusedConvParameters : public ConvParameters { public: FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, - const SpatialArray& stride, const SpatialArray& padding, - DataType dtype, int device_id, bool has_side_input, + const SpatialArray& dilation, const SpatialArray& stride, + const SpatialArray& padding, DataType dtype, + int device_id, bool has_side_input, ActivationMode activation_mode) - : ConvParameters(batch, in_depths, in, out_depths, filter, stride, - padding, dtype, device_id), + : ConvParameters(batch, in_depths, in, out_depths, filter, dilation, + stride, padding, dtype, device_id), activation_mode_(activation_mode), has_side_input_(has_side_input) { hash_code_ = Hash64Combine(hash_code_, has_side_input); diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 887ebc5a6c35379476fa1a643c866d38e2b25699..6a56237f67c844a3daa546eb02d64c9e2658f639 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -52,6 +52,7 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") .Attr("activation_mode: {'Relu'} = 'Relu'") + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; using shape_inference::DimensionHandle; @@ -151,6 +152,11 @@ REGISTER_OP("FusedConv2DBiasActivation") kernel_height, kernel_width, input_channels % 4 ]` activation_mode: The activation applied to the output. Currently must be "Relu". + dilations: 1-D tensor of length 4. The dilation factor for each dimension + of `input`. If set to k > 1, there will be k-1 skipped cells between + each filter element on that dimension. The dimension order is determined + by the value of `data_format`, see above for details. Dilations in the + batch and depth dimensions must be 1. )doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 1418c87023af0dbff890f46e10f0140d5b89e4b7..a2e6fa51f1e1cea1d995204d84a620a991cfb7ba 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -116,6 +116,7 @@ py_library( deps = [ ":clip_weights", ":conditioning_utils", + ":random_tensor_pool", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -219,6 +220,37 @@ py_test( ], ) +py_library( + name = "random_tensor_pool", + srcs = [ + "python/features/python/random_tensor_pool.py", + "python/features/python/random_tensor_pool_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:util", + ], +) + +py_test( + name = "random_tensor_pool_test", + srcs = ["python/features/python/random_tensor_pool_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":random_tensor_pool", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//third_party/py/numpy", + ], +) + py_library( name = "virtual_batchnorm", srcs = [ diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 3ab84780705b35567169bd76fd3485ad355ba9d8..4bca0a1d62a2b404c6783c7cfe3b5c67cfc58221 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -8,7 +8,8 @@ explicitly model the distribution and without writing an explicit loss. For example, the generator could learn to draw samples from the distribution of natural images. For more details on this technique, see ['Generative Adversarial Networks'](https://arxiv.org/abs/1406.2661) by -Goodfellow et al. +Goodfellow et al. See [tensorflow/models](https://github.com/tensorflow/models/tree/master/research/gan/) for examples, and [this tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb) for an +introduction. #### Usage ```python @@ -23,8 +24,8 @@ mix TFGAN, native TF, and other custom frameworks * Use already implemented [GAN losses and penalties](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/losses/python/losses_impl.py) (ex Wasserstein loss, gradient penalty, mutual information penalty, etc) * [Monitor and visualize](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/summaries_impl.py) GAN progress during training, and [evaluate](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py) them * Use already-implemented [tricks](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/features/python/) to stabilize and improve training -* Develop based on examples of common GAN setups -* Use the TFGAN-backed tf.Learn Estimator to easily train a GAN model +* Develop based on examples of [common GAN setups](https://github.com/tensorflow/models/tree/master/research/gan/) +* Use the TFGAN-backed [GANEstimator](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py) to easily train a GAN model * Improvements in TFGAN infrastructure will automatically benefit your TFGAN project * Stay up-to-date with research as we add more algorithms @@ -51,7 +52,7 @@ network to evaluate your unconditional generative model. You can also use your own pretrained classifier for more specific performance numbers, or use other methods for evaluating conditional generative models. -* examples (coming soon): +* [examples](https://github.com/tensorflow/models/tree/master/research/gan/) and [tutorial](https://github.com/tensorflow/models/tree/master/research/gan/tutorial.ipynb): See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your own project. These include unconditional and conditional GANs, InfoGANs, adversarial losses on existing networks, and image-to-image translation. diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index dff361fdc42708ea69999c2def4721f9d49fcf14..f1946c7f925660eae3aaa650c437e03da1f33d6c 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN is a lightweight library for training and evaluating GANs. + +In addition to providing the infrastructure for easily training and evaluating +GANS, this library contains modules for a TFGAN-backed Estimator, +evaluation metrics, features (such as virtual batch normalization), and losses. +Please see README.md for details and usage. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 8c4a18228039cb4f2c06e0333f4b8408f1f631e9..c9f7bc61b25230e4159cf8cbc7c9cceead0aa706 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN estimator module. + +GANEstimator provides all the infrastructure support of a TensorFlow Estimator +with the feature support of TFGAN. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index e89993991a389d68254a95aded2d771f4c2627be..9d14f391332fa95035bf96f8f37930af595634a9 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import enum from tensorflow.contrib.framework.python.ops import variables as variable_lib @@ -29,6 +30,7 @@ from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope +from tensorflow.python.util import tf_inspect as inspect __all__ = [ @@ -76,7 +78,7 @@ class GANEstimator(estimator.Estimator): return logits # Create GAN estimator. - gan_estimator = estimator.GANEstimator( + gan_estimator = tfgan.estimator.GANEstimator( model_dir, generator_fn=generator_fn, discriminator_fn=discriminator_fn, @@ -105,6 +107,7 @@ class GANEstimator(estimator.Estimator): discriminator_loss_fn=None, generator_optimizer=None, discriminator_optimizer=None, + get_hooks_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -116,7 +119,10 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. + generator. See `TFGAN` for more details and examples. Additionally, if + it has an argument called `mode`, the Estimator's `mode` will be passed + in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch + normalization. discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details @@ -132,6 +138,10 @@ class GANEstimator(estimator.Estimator): work. discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. These hooks are run on the generator and discriminator + train ops, and can be used to implement the GAN training scheme. + Defaults to `train.get_sequential_train_hooks()`. add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. @@ -146,7 +156,7 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries) + use_loss_summaries, get_hooks_fn=get_hooks_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) @@ -225,9 +235,12 @@ def _gan_model_fn( labels=None) -def _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries): - """Make a `GANModel` for training.""" +def _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, mode): + """Make a `GANModel`, and optionally pass in `mode`.""" + # If `generator_fn` has an argument `mode`, pass mode to it. + if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, mode=mode) gan_model = tfgan_train.gan_model( generator_fn, discriminator_fn, @@ -245,15 +258,28 @@ def _make_train_gan_model(generator_fn, discriminator_fn, real_data, return gan_model +def _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for training.""" + return _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, + model_fn_lib.ModeKeys.TRAIN) + + def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries): """Make a `GANModel` for evaluation.""" - return _make_train_gan_model(generator_fn, discriminator_fn, real_data, - generator_inputs, generator_scope, add_summaries) + return _make_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries, + model_fn_lib.ModeKeys.EVAL) def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): """Make a `GANModel` from just the generator.""" + # If `generator_fn` has an argument `mode`, pass mode to it. + if 'mode' in inspect.getargspec(generator_fn).args: + generator_fn = functools.partial(generator_fn, + mode=model_fn_lib.ModeKeys.PREDICT) with variable_scope.variable_scope(generator_scope) as gen_scope: generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access generated_data = generator_fn(generator_inputs) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 1bfdce9ee94d4d05d5186cd999361662bc0e3f85..e752f0bcccda418b79d4fdabb27807394cbbb425 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -48,7 +48,8 @@ from tensorflow.python.training import training from tensorflow.python.training import training_util -def generator_fn(noise_dict): +def generator_fn(noise_dict, mode): + del mode noise = noise_dict['x'] return layers.fully_connected(noise, noise.shape[1].value) @@ -90,7 +91,6 @@ def mock_head(testcase, expected_generator_inputs, expected_real_data, generator_var_names, set([x.name for x in gan_model.generator_variables])) testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) - testcase.assertEqual(generator_fn, gan_model.generator_fn) testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) # TODO(joelshor): Add check on `discriminator_real_outputs`. # TODO(joelshor): Add check on `discriminator_gen_outputs`. diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 204c646e194319c0e63599da0b2a4909ef270ef3..a21358c50bbdb4a1a929b0c5bc322cec4c9923b5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -71,7 +71,7 @@ class GANHead(head._Head): # pylint: disable=protected-access def __init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, - get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + get_hooks_fn=None, name=None): """`Head` for GAN training. @@ -86,10 +86,12 @@ class GANHead(head._Head): # pylint: disable=protected-access use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + of hooks. Defaults to `train.get_sequential_train_hooks()` name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + if get_hooks_fn is None: + get_hooks_fn = tfgan_train.get_sequential_train_hooks() # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index bb8046187807d0cc584f7174eb9aac578855c110..7daf78bc5dcab87f6fa31a8334269d31e94576d4 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN evaluation module. + +This module supports techniques such as Inception Score, Frechet Inception +distance, and Sliced Wasserstein distance. +""" # pylint: disable=,wildcard-import,unused-import from __future__ import absolute_import diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index d4c080cab3d82f6a69a293e84e1c08322bbb6f86..82293b575aefa198a618ae7286ca24ebabd6987d 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -57,8 +57,10 @@ __all__ = [ 'run_inception', 'inception_score', 'classifier_score', + 'classifier_score_from_logits', 'frechet_inception_distance', 'frechet_classifier_distance', + 'frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] @@ -130,10 +132,10 @@ def preprocess_image( with ops.name_scope(scope, 'preprocess', [images, height, width]): if not images.dtype.is_floating: images = math_ops.to_float(images) - images = (images - 128.0) / 128.0 if is_single: images = array_ops.expand_dims(images, axis=0) resized = image_ops.resize_bilinear(images, [height, width]) + resized = (resized - 128.0) / 128.0 if is_single: resized = array_ops.squeeze(resized, axis=0) return resized @@ -222,13 +224,13 @@ def run_inception(images, image_size: Required image width and height. See unit tests for the default values. input_tensor: Name of input Tensor. - output_tensor: Name of output Tensor. This function will compute activations - at the specified layer. Examples include INCEPTION_V3_OUTPUT and - INCEPTION_V3_FINAL_POOL which would result in this function computing + output_tensor: Name or list of output Tensors. This function will compute + activations at the specified layer. Examples include INCEPTION_V3_OUTPUT + and INCEPTION_V3_FINAL_POOL which would result in this function computing the final logits or the penultimate pooling layer. Returns: - Logits. + Tensor or Tensors corresponding to computed `output_tensor`. Raises: ValueError: If images are not the correct size. @@ -244,8 +246,14 @@ def run_inception(images, activations = run_image_classifier(images, graph_def, input_tensor, output_tensor) - if array_ops.rank(activations) != 2: - activations = layers.flatten(activations) + if isinstance(activations, list): + for i, activation in enumerate(activations): + if array_ops.rank(activation) != 2: + activations[i] = layers.flatten(activation) + else: + if array_ops.rank(activations) != 2: + activations = layers.flatten(activations) + return activations @@ -257,23 +265,26 @@ def run_image_classifier(tensor, graph_def, input_tensor, tensor: An Input tensor. graph_def: A GraphDef proto. input_tensor: Name of input tensor in graph def. - output_tensor: Name of output tensor in graph def. + output_tensor: A tensor name or list of tensor names in graph def. scope: Name scope for classifier. Returns: - Classifier output. Shape depends on the classifier used, but is often - [batch, classes]. + Classifier output if `output_tensor` is a string, or a list of outputs if + `output_tensor` is a list. Raises: - ValueError: If `image_size` is not `None`, and `tensor` are not the correct - size. + ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def. """ input_map = {input_tensor: tensor} - return_elements = [output_tensor] - classifier_output = importer.import_graph_def( - graph_def, input_map, return_elements, name=scope)[0] + is_singleton = isinstance(output_tensor, str) + if is_singleton: + output_tensor = [output_tensor] + classifier_outputs = importer.import_graph_def( + graph_def, input_map, output_tensor, name=scope) + if is_singleton: + classifier_outputs = classifier_outputs[0] - return classifier_output + return classifier_outputs def classifier_score(images, classifier_fn, num_batches=1): @@ -297,7 +308,8 @@ def classifier_score(images, classifier_fn, num_batches=1): efficiently run them through the classifier network. Returns: - The classifier score. A floating-point scalar. + The classifier score. A floating-point scalar of the same type as the output + of `classifier_fn`. """ generated_images_list = array_ops.split( images, num_or_size_splits=num_batches) @@ -311,12 +323,36 @@ def classifier_score(images, classifier_fn, num_batches=1): swap_memory=True, name='RunClassifier') logits = array_ops.concat(array_ops.unstack(logits), 0) + + return classifier_score_from_logits(logits) + + +def classifier_score_from_logits(logits): + """Classifier score for evaluating a conditional generative model. + + This is based on the Inception Score, but for an arbitrary classifier. + + This technique is described in detail in https://arxiv.org/abs/1606.03498. In + summary, this function calculates + + exp( E[ KL(p(y|x) || p(y)) ] ) + + which captures how different the network's classification prediction is from + the prior distribution over classes. + + Args: + logits: A 2D Tensor of logits. + + Returns: + The classifier score. A floating-point scalar of the same type as the output + of `logits`. + """ logits.shape.assert_has_rank(2) # Use maximum precision for best results. logits_dtype = logits.dtype if logits_dtype != dtypes.float64: - logits = math_ops.cast(logits, dtypes.float64) + logits = math_ops.to_double(logits) p = nn_ops.softmax(logits) q = math_ops.reduce_mean(p, axis=0) @@ -326,7 +362,7 @@ def classifier_score(images, classifier_fn, num_batches=1): final_score = math_ops.exp(log_score) if logits_dtype != dtypes.float64: - final_score = math_ops.cast(final_score, dtypes.float64) + final_score = math_ops.cast(final_score, logits_dtype) return final_score @@ -415,7 +451,8 @@ def frechet_classifier_distance(real_images, efficiently run them through the classifier network. Returns: - The Frechet Inception distance. A floating-point scalar. + The Frechet Inception distance. A floating-point scalar of the same type + as the output of `classifier_fn` """ real_images_list = array_ops.split( @@ -440,20 +477,65 @@ def frechet_classifier_distance(real_images, # Ensure the activations have the right shapes. real_a = array_ops.concat(array_ops.unstack(real_a), 0) gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) - real_a.shape.assert_has_rank(2) - gen_a.shape.assert_has_rank(2) + + return frechet_classifier_distance_from_activations(real_a, gen_a) + + +def frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model. + + This is based on the Frechet Inception distance, but for an arbitrary + classifier. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + Args: + real_activations: Real images to use to compute Frechet Inception distance. + generated_activations: Generated images to use to compute Frechet Inception + distance. + + Returns: + The Frechet Inception distance. A floating-point scalar of the same type + as the output of the activations. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) # Compute mean and covariance matrices of activations. - m = math_ops.reduce_mean(real_a, 0) - m_v = math_ops.reduce_mean(gen_a, 0) - num_examples = math_ops.to_float(array_ops.shape(real_a)[0]) + m = math_ops.reduce_mean(real_activations, 0) + m_v = math_ops.reduce_mean(generated_activations, 0) + num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T + real_centered = real_activations - m sigma = math_ops.matmul( - real_a - m, real_a - m, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_v sigma_v = math_ops.matmul( - gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1) + gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) # Find the Tr(sqrt(sigma sigma_v)) component of FID sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) @@ -467,6 +549,8 @@ def frechet_classifier_distance(real_images, # Next the distance between means. mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm. fid = trace + mean + if activations_dtype != dtypes.float64: + fid = math_ops.cast(fid, activations_dtype) return fid diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 81fa2fc0f126647d2f01a1f4fc695d714eba2c75..1e18c699ba93b5f524341c65d0a2db84556b65a2 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -190,6 +190,23 @@ class ClassifierMetricsTest(test.TestCase): # Check that none of the model variables are trainable. self.assertListEqual([], variables.trainable_variables()) + def test_run_inception_multiple_outputs(self): + """Test `run_inception` graph construction with multiple outputs.""" + batch_size = 3 + img = array_ops.ones([batch_size, 299, 299, 3]) + logits, pool = _run_with_mock( + classifier_metrics.run_inception, img, + output_tensor=[classifier_metrics.INCEPTION_OUTPUT, + classifier_metrics.INCEPTION_FINAL_POOL]) + + self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertTrue(isinstance(pool, ops.Tensor)) + logits.shape.assert_is_compatible_with([batch_size, 1001]) + pool.shape.assert_is_compatible_with([batch_size, 2048]) + + # Check that none of the model variables are trainable. + self.assertListEqual([], variables.trainable_variables()) + def test_inception_score_graph(self): """Test `inception_score` graph construction.""" score = _run_with_mock(classifier_metrics.inception_score, @@ -277,7 +294,7 @@ class ClassifierMetricsTest(test.TestCase): expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a) - self.assertAllClose(expected_fid, actual_fid, 0.01) + self.assertAllClose(expected_fid, actual_fid, 0.0001) def test_trace_sqrt_product_value(self): """Test that `trace_sqrt_product` gives the correct value.""" diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 6d0972f8db418d6fcf517cc6f7e96093ae08a9e4..4816daf760143af9f1502873b123ffad8e5ec8ce 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN features module. + +This module includes support for virtual batch normalization, buffer replay, +conditioning, etc. +""" from __future__ import absolute_import from __future__ import division @@ -22,10 +26,12 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * +from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -33,5 +39,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ +_allowed_symbols += random_tensor_pool.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py index 030e37ec679ec58e3b534fd3644ffe1d23173404..2b7bb5f14e7f3d1b3f913d3426efaaae19079ffb 100644 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py +++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for tfgan.python.features.clip_weights.""" +"""Tests for features.clip_weights.""" from __future__ import absolute_import from __future__ import division @@ -31,17 +31,18 @@ class ClipWeightsTest(test.TestCase): """Tests for `discriminator_weight_clip`.""" def setUp(self): + super(ClipWeightsTest, self).setUp() self.variables = [variables.Variable(2.0)] self.tuple = collections.namedtuple( 'VarTuple', ['discriminator_variables'])(self.variables) def _test_weight_clipping_helper(self, use_tuple): - loss = self.variables[0] * 2.0 + loss = self.variables[0] opt = training.GradientDescentOptimizer(1.0) if use_tuple: - opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1) + opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1) else: - opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1) + opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1) train_op1 = opt.minimize(loss, var_list=self.variables) train_op2 = opt_clip.minimize(loss, var_list=self.variables) @@ -72,10 +73,14 @@ class ClipWeightsTest(test.TestCase): clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1) else: with self.assertRaisesRegexp(ValueError, 'must be positive'): - clip_weights.clip_weights(opt, self.variables, weight_clip=-1) + clip_weights.clip_variables(opt, self.variables, weight_clip=-1) def test_incorrect_weight_clip_value_argsonly(self): self._test_incorrect_weight_clip_value_helper(False) def test_incorrect_weight_clip_value_tuple(self): self._test_incorrect_weight_clip_value_helper(True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..ca904971fa8cb0440d3e0c9060f13cc214c9eaad --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py @@ -0,0 +1,35 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A tensor pool stores values from an input tensor and returns a stored one. + +See the following papers for more details. +1) `Learning from simulated and unsupervised images through adversarial + training` (https://arxiv.org/abs/1612.07828). +2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial + Networks` (https://arxiv.org/abs/1703.10593). +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import random_tensor_pool_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = random_tensor_pool_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9d733b6ff9f6afc44e8a0d9364729de506fc36d2 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -0,0 +1,134 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A tensor pool stores values from an input tensor and returns a stored one. + +We use this to keep a history of values created by a generator, such that +a discriminator can randomly be trained on some older samples, not just the +current one. This can help to not let the discriminator get too far ahead of the +generator and also to keep the system from oscilating, if the discriminator +forgets too fast what past samples from the generator looked like. + +See the following papers for more details. +1) `Learning from simulated and unsupervised images through adversarial + training` (https://arxiv.org/abs/1612.07828). +2) `Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial + Networks` (https://arxiv.org/abs/1703.10593). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import random_ops + +__all__ = [ + 'tensor_pool', +] + + +def _to_tuple(x): + if isinstance(x, (list, tuple)): + return tuple(x) + return (x,) + + +def tensor_pool(input_values, + pool_size, + pooling_probability=0.5, + name='tensor_pool'): + """Queue storing input values and returning random previously stored ones. + + Every time the returned `output_value` is evaluated, `input_value` is + evaluated and its value either directly returned (with + `1-pooling_probability`) or stored in the pool and a random one of the samples + currently in the pool is popped and returned. As long as the pool in not fully + filled, the input_value is always directly returned, as well as stored in the + pool. Note during inference / testing, it may be appropriate to set + `pool_size` = 0 or `pooling_probability` = 0. + + Args: + input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read + values to be pooled. + pool_size: An integer specifying the maximum size of the pool. + pooling_probability: A float `Tensor` specifying the probability of getting + a value from the pool, as opposed to just the current input. + name: A string prefix for the name scope for all tensorflow ops. + + Returns: + A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx + `input_values`) which is with given probability either the `input_values` or + a randomly chosen sample that was previously inserted in the pool. + + Raises: + ValueError: If `pool_size` is negative. + """ + pool_size = int(pool_size) + if pool_size < 0: + raise ValueError('`pool_size` is negative.') + elif pool_size == 0: + return input_values + + original_input_values = input_values + input_values = _to_tuple(input_values) + + with ops.name_scope( + '{}_pool_queue'.format(name), + values=input_values + (pooling_probability,)): + pool_queue = data_flow_ops.RandomShuffleQueue( + capacity=pool_size, + min_after_dequeue=0, + dtypes=[v.dtype for v in input_values], + shapes=None) + + # In pseudeo code this code does the following: + # if not pool_full: + # enqueue(input_values) + # return input_values + # else + # dequeue_values = dequeue_random_sample() + # enqueue(input_values) + # if rand() < pooling_probability: + # return dequeue_values + # else + # return input_values + + def _get_input_value_pooled(): + enqueue_op = pool_queue.enqueue(input_values) + with ops.control_dependencies([enqueue_op]): + return tuple(array_ops.identity(v) for v in input_values) + + def _get_random_pool_value_and_enqueue_input(): + dequeue_values = _to_tuple(pool_queue.dequeue()) + with ops.control_dependencies(dequeue_values): + enqueue_op = pool_queue.enqueue(input_values) + with ops.control_dependencies([enqueue_op]): + prob = random_ops.random_uniform( + (), dtype=dtypes.float32) < pooling_probability + return control_flow_ops.cond(prob, lambda: dequeue_values, + lambda: input_values) + + output_values = _to_tuple(control_flow_ops.cond( + pool_queue.size() < pool_size, _get_input_value_pooled, + _get_random_pool_value_and_enqueue_input)) + + if isinstance(original_input_values, list): + return list(output_values) + elif isinstance(original_input_values, tuple): + return output_values + return output_values[0] diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cef3a87ab34f9754099073eefcb3f1b1c97a3762 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -0,0 +1,110 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.gan.python.features.random_tensor_pool.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class TensorPoolTest(test.TestCase): + + def test_pool_unknown_input_shape(self): + """Checks that `input_value` can have unknown shape.""" + input_value = array_ops.placeholder( + dtype=dtypes.int32, shape=[None, None, 3]) + output_value = tensor_pool(input_value, pool_size=10) + + with self.test_session(use_gpu=True) as session: + for i in range(10): + session.run(output_value, {input_value: [[[i] * 3]]}) + session.run(output_value, {input_value: [[[i] * 3] * 2]}) + session.run(output_value, {input_value: [[[i] * 3] * 5] * 2}) + + def test_pool_sequence(self): + """Checks that values are pooled and returned maximally twice.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + output_value = tensor_pool(input_value, pool_size=10) + + with self.test_session(use_gpu=True) as session: + outs = [] + for i in range(50): + out = session.run(output_value, {input_value: i}) + outs.append(out) + self.assertLessEqual(out, i) + + _, counts = np.unique(outs, return_counts=True) + # Check that each value is returned maximally twice. + self.assertTrue((counts <= 2).all()) + + def test_never_pool(self): + """Checks that setting `pooling_probability` to zero works.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + output_value = tensor_pool( + input_value, pool_size=10, pooling_probability=0.0) + + with self.test_session(use_gpu=True) as session: + for i in range(50): + out = session.run(output_value, {input_value: i}) + self.assertEqual(out, i) + + def test_pooling_probability(self): + """Checks that `pooling_probability` works.""" + input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) + pool_size = 10 + pooling_probability = 0.2 + output_value = tensor_pool( + input_value, + pool_size=pool_size, + pooling_probability=pooling_probability) + + with self.test_session(use_gpu=True) as session: + not_pooled = 0 + total = 1000 + for i in range(total): + out = session.run(output_value, {input_value: i}) + if out == i: + not_pooled += 1 + self.assertAllClose( + (not_pooled - pool_size) / (total - pool_size), + 1 - pooling_probability, + atol=0.03) + + def test_input_values_tuple(self): + """Checks that `input_values` can be a tuple.""" + input_values = (array_ops.placeholder(dtype=dtypes.int32, shape=[]), + array_ops.placeholder(dtype=dtypes.int32, shape=[])) + output_values = tensor_pool(input_values, pool_size=3) + self.assertEqual(len(output_values), len(input_values)) + + with self.test_session(use_gpu=True) as session: + for i in range(10): + outs = session.run(output_values, { + input_values[0]: i, + input_values[1]: i + 1 + }) + self.assertEqual(len(outs), len(input_values)) + self.assertEqual(outs[1] - outs[0], 1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/losses/__init__.py b/tensorflow/contrib/gan/python/losses/__init__.py index 290ff867a1e443f20a63e27fd97f53fed8a6cc11..d9bf8ebfdf65dfc76e4569dcaf26e0e51c7fc107 100644 --- a/tensorflow/contrib/gan/python/losses/__init__.py +++ b/tensorflow/contrib/gan/python/losses/__init__.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN losses and penalties. + +Losses can be used with individual arguments or with GANModel tuples. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 48f5e8e47dbcd5d32c23806b967a0d1e7403d2f7..3d4e315ebd0bd52b3b5e3e4a8655df8bfe9cebe8 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -79,6 +79,7 @@ class InfoGANModel( collections.namedtuple('InfoGANModel', GANModel._fields + ( 'structured_generator_inputs', 'predicted_distributions', + 'discriminator_and_aux_fn', ))): """An InfoGANModel contains all the pieces needed for InfoGAN training. @@ -91,6 +92,8 @@ class InfoGANModel( predicted_distributions: A list of tf.Distributions. Predicted by the recognizer, and used to evaluate the likelihood of the structured noise. List length should match `structured_generator_inputs`. + discriminator_and_aux_fn: The original discriminator function that returns + a tuple of (logits, `predicted_distributions`). """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 06dd281489be7b12d9123ca83d926bc7b81f7e10..27c1a2245135299ac943bc2b2dd89dd10e52ea1b 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -58,6 +58,7 @@ __all__ = [ 'get_sequential_train_hooks', 'get_joint_train_hooks', 'get_sequential_train_steps', + 'RunTrainOpsHook', ] @@ -214,7 +215,8 @@ def infogan_model( disc_scope, lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API structured_generator_inputs, - predicted_distributions) + predicted_distributions, + discriminator_fn) def acgan_model( @@ -421,7 +423,7 @@ def gan_loss( ac_disc_loss = tfgan_losses.acgan_discriminator_loss( model, add_summaries=add_summaries) dis_loss += aux_cond_discriminator_weight * ac_disc_loss - # Gathers auxilliary losses. + # Gathers auxiliary losses. if model.generator_scope: gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name) else: diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 6b27b6926102b6e5a7ff134ceed75c23459a6534..4d4ede706c51ec17d0ea5bd1854ea2cd79358bdb 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -145,14 +145,16 @@ def get_infogan_model(): return namedtuples.InfoGANModel( *get_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def get_callable_infogan_model(): return namedtuples.InfoGANModel( *get_callable_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def create_infogan_model(): diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index a417dba87543d82526ab856e5b915ee47f496d46..bdbe6f0a72621e59562fe113da101ff5a2b8c06d 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -103,6 +103,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_cache", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_interface", diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 959905e9826fe439112078a32fef9a5f5b96e9ac..30bc33b9ee42ba78bc7307c67c0fc0af9f3356ef 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -93,6 +93,8 @@ class ListView(object): # TODO(fkp): very generic code, it should be moved in a more generic place. def is_iterable(obj): """Return true if the object is iterable.""" + if isinstance(obj, tf_ops.Tensor): + return False try: _ = iter(obj) except Exception: # pylint: disable=broad-except diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 157e97d237021d95c935a6be66aa57842b97125c..54502cfc6eecb9d064ffde9773e97d893a24133a 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -9,6 +9,7 @@ package(default_visibility = ["//visibility:public"]) load( "//tensorflow:tensorflow.bzl", + "tf_cc_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", @@ -106,10 +107,33 @@ tf_custom_op_library( name = "python/ops/_distort_image_ops.so", srcs = [ "kernels/adjust_hsv_in_yiq_op.cc", + "kernels/adjust_hsv_in_yiq_op.h", "ops/distort_image_ops.cc", ], + gpu_srcs = [ + "kernels/adjust_hsv_in_yiq_op_gpu.cu.cc", + "kernels/adjust_hsv_in_yiq_op.h", + ], deps = [ - "@protobuf_archive//:protobuf", + "//tensorflow/core/kernels:gpu_util_hdrs", + ], +) + +tf_cc_test( + name = "adjust_hsv_in_yiq_op_test", + size = "small", + srcs = [ + "kernels/adjust_hsv_in_yiq_op.h", + "kernels/adjust_hsv_in_yiq_op_test.cc", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", ], ) @@ -122,19 +146,6 @@ tf_gen_op_wrapper_py( deps = [":distort_image_ops_op_lib"], ) -cc_library( - name = "distort_image_ops_cc", - srcs = [ - "kernels/adjust_hsv_in_yiq_op.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//third_party/eigen3", - ], - alwayslink = 1, -) - py_library( name = "distort_image_py", srcs = [ diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc index f4962ed69dc68d4bad06ef29d7a167e0ba8ae044..478b716d88321101c971789f36c0ff8ecd3f418e 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc @@ -12,14 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" @@ -36,10 +37,10 @@ class AdjustHsvInYiqOpBase : public OpKernel { struct ComputeOptions { const Tensor* input = nullptr; + Tensor* output = nullptr; const Tensor* delta_h = nullptr; const Tensor* scale_s = nullptr; const Tensor* scale_v = nullptr; - Tensor* output = nullptr; int64 channel_count = 0; }; @@ -65,7 +66,7 @@ class AdjustHsvInYiqOpBase : public OpKernel { scale_v.shape().DebugString())); auto channels = input.dim_size(input.dims() - 1); OP_REQUIRES( - context, channels == 3, + context, channels == kChannelSize, errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); @@ -101,53 +102,21 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { const Tensor* input = options.input; Tensor* output = options.output; const int64 channel_count = options.channel_count; - static const int kChannelSize = 3; auto input_data = input->shaped({channel_count, kChannelSize}); const float delta_h = options.delta_h->scalar()(); const float scale_s = options.scale_s->scalar()(); const float scale_v = options.scale_v->scalar()(); auto output_data = output->shaped({channel_count, kChannelSize}); + float tranformation_matrix[kChannelSize * kChannelSize] = {0}; + internal::compute_tranformation_matrix( + delta_h, scale_s, scale_v, tranformation_matrix); const int kCostPerChannel = 10; const DeviceBase::CpuWorkerThreads& worker_threads = *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, kCostPerChannel, - [channel_count, &input_data, &output_data, delta_h, scale_s, scale_v]( + [channel_count, &input_data, &output_data, &tranformation_matrix]( int64 start_channel, int64 end_channel) { - // Using approximate linear transfomation described in: - // https://beesbuzz.biz/code/hsv_color_transforms.php - /** Get the constants from sympy - from sympy import Matrix - from sympy.abc import u, w - # Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ - tyiq = Matrix([[0.299, 0.587, 0.114], - [0.596, -0.274, -0.322], - [0.211, -0.523, 0.312]]) - # Hue rotation matrix in YIQ space. - hue_proj = Matrix(3,3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu]) - m = tyiq.inv() * hue_proj * tyiq - **/ - // TODO(huangyp): directly compute the projection matrix from tyiq. - static const float t[kChannelSize][kChannelSize][kChannelSize] = { - {{.299, .701, .16862179492229}, - {.587, -.587, .329804745287403}, - {.114, -.114, -0.498426540209694}}, - {{.299, -.299, -.327963394172371}, - {.587, .413, .0346106879248821}, - {.114, -.114, .293352706247489}}, - {{.299, -.299, 1.24646136576682}, - {.587, -.587, -1.04322888291964}, - {.114, .886, -.203232482847173}}}; - float m[kChannelSize][kChannelSize] = {{0.}}; - float su = scale_s * std::cos(delta_h); - float sw = scale_s * std::sin(delta_h); - for (int q_index = 0; q_index < kChannelSize; q_index++) { - for (int p_index = 0; p_index < kChannelSize; p_index++) { - m[q_index][p_index] = scale_v * (t[q_index][p_index][0] + - t[q_index][p_index][1] * su + - t[q_index][p_index][2] * sw); - } - } // Applying projection matrix to input RGB vectors. const float* p = input_data.data() + start_channel * kChannelSize; float* q = output_data.data() + start_channel * kChannelSize; @@ -155,7 +124,9 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { for (int q_index = 0; q_index < kChannelSize; q_index++) { q[q_index] = 0; for (int p_index = 0; p_index < kChannelSize; p_index++) { - q[q_index] += m[q_index][p_index] * p[p_index]; + q[q_index] += + p[p_index] * + tranformation_matrix[q_index + kChannelSize * p_index]; } } p += kChannelSize; @@ -165,8 +136,33 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { } }; -REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU), - AdjustHsvInYiqOp); +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint("T"), + AdjustHsvInYiqOp); + +#if GOOGLE_CUDA +template <> +class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { + public: + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) + : AdjustHsvInYiqOpBase(context) {} + + void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override { + const int64 number_of_elements = options.input->NumElements(); + if (number_of_elements <= 0) { + return; + } + const float* delta_h = options.delta_h->flat().data(); + const float* scale_s = options.scale_s->flat().data(); + const float* scale_v = options.scale_v->flat().data(); + functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input, + delta_h, scale_s, scale_v, options.output); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint("T"), + AdjustHsvInYiqOp); +#endif -// TODO(huangyp): add the GPU kernel } // namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h new file mode 100644 index 0000000000000000000000000000000000000000..194ae2ba47456cac66c01989a78ab4ce607d1295 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +static constexpr int kChannelSize = 3; + +namespace internal { + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_tranformation_matrix( + const float delta_h, const float scale_s, const float scale_v, + float* matrix) { + static_assert(MATRIX_SIZE == kChannelSize * kChannelSize, + "Size of matrix should be 9."); + // Projection matrix from RGB to YIQ. Numbers from wikipedia + // https://en.wikipedia.org/wiki/YIQ + Eigen::Matrix3f yiq; + /* clang-format off */ + yiq << 0.299, 0.587, 0.114, + 0.596, -0.274, -0.322, + 0.211, -0.523, 0.312; + Eigen::Matrix3f yiq_inverse; + yiq_inverse << 1, 0.95617069, 0.62143257, + 1, -0.2726886, -0.64681324, + 1, -1.103744, 1.70062309; + /* clang-format on */ + // Construct hsv linear transformation matrix in YIQ space. + // https://beesbuzz.biz/code/hsv_color_transforms.php + float vsu = scale_v * scale_s * std::cos(delta_h); + float vsw = scale_v * scale_s * std::sin(delta_h); + Eigen::Matrix3f hsv_transform; + /* clang-format off */ + hsv_transform << scale_v, 0, 0, + 0, vsu, -vsw, + 0, vsw, vsu; + /* clang-format on */ + // Compute final transformation matrix = inverse_yiq * hsv_transform * yiq + Eigen::Map> eigen_matrix(matrix); + eigen_matrix = yiq_inverse * hsv_transform * yiq; +} +} // namespace internal + +#if GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +struct AdjustHsvInYiqGPU { + void operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, const float* const delta_h, + const float* const scale_s, const float* const scale_v, + Tensor* const output); +}; + +} // namespace functor + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..b71ff9cd507faac66b3a33d3c02ec9b5901d814a --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +namespace internal { + +__global__ void compute_tranformation_matrix_cuda(const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + float* const matrix, + const int matrix_size) { + if (matrix_size == kChannelSize * kChannelSize) { + compute_tranformation_matrix( + *delta_h, *scale_s, *scale_v, matrix); + } +} +} // namespace internal + +namespace functor { + +void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, + const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + Tensor* const output) { + const uint64 m = channel_count; + const uint64 k = kChannelSize; + const uint64 n = kChannelSize; + auto* cu_stream = ctx->eigen_device().stream(); + OP_REQUIRES(ctx, cu_stream, errors::Internal("No GPU stream available.")); + Tensor tranformation_matrix; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DT_FLOAT, TensorShape({kChannelSize * kChannelSize}), + &tranformation_matrix)); + // TODO(huangyp): It takes about 3.5 us to comute tranformation_matrix + // with one thread. Improve its performance if necessary. + internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>( + delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + // Call cuBlas C = A * B directly. + auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + auto a_ptr = + AsDeviceMemory(input->flat().data(), input->flat().size()); + auto b_ptr = AsDeviceMemory(tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + auto c_ptr = AsDeviceMemory(output->flat().data(), + output->flat().size()); + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + // TODO(huangyp): share/use autotune cublas algorithms in Matmul.op. + bool blas_launch_status = + stream + ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } +} +} // namespace functor +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4cbbd277840133c9419f9ce3d945b7d099679dc0 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class AdjustHsvInYiqOpTest : public OpsTestBase { + protected: +}; + +TEST_F(AdjustHsvInYiqOpTest, IdentiyTransformMatrix) { + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, 1.0, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, {1, 0, 0, 0, 1, 0, 0, 0, 1}); + test::ExpectClose(matrix, expected); +} + +TEST_F(AdjustHsvInYiqOpTest, ScaleValueTransformMatrix) { + float scale_v = 2.3; + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, scale_v, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, + {scale_v, 0, 0, 0, scale_v, 0, 0, 0, scale_v}); + test::ExpectClose(matrix, expected); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index 2b6799213827537f77deda4e052bb7ec16f46343..f8b56ab1c5400694b3aa8d4a0c19c7769aa8cbce 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -40,7 +40,7 @@ REGISTER_OP("SingleImageRandomDotStereograms") .Doc(R"doc( Outputs a single image random dot stereogram for export via encode_PNG/JPG OP. -Given the 2-D tensor 'depth_values' with encoded Z values, this operation will +Given the 2-D tensor 'depth_values' with encoded Z values, this operation will encode 3-D data into a 2-D image. The output of this Op is suitable for the encode_PNG/JPG ops. Be careful with image compression as this may corrupt the encode 3-D data witin the image. @@ -68,14 +68,14 @@ with open('picture_out.png', 'wb') as f: f.write(png) ``` -depth_values: Z values of data to encode into 'output_data_window' window, +depth_values: Z values of data to encode into 'output_data_window' window, lower values are further away {0.0 floor(far), 1.0 ceiling(near) after normalization}, must be 2-D tensor hidden_surface_removal: Activate hidden surface removal convergence_dots_size: Black dot size in pixels to help view converge image, drawn on bottom of image dots_per_inch: Output device in dots/inch eye_separation: Separation between eyes in inches mu: Depth of field, Fraction of viewing distance (eg. 1/3 = .3333) -normalize: Normalize input data to [0.0, 1.0] +normalize: Normalize input data to [0.0, 1.0] normalize_max: Fix MAX value for Normalization - if < MIN, autoscale normalize_min: Fix MIN value for Normalization - if > MAX, autoscale border_level: Value of border depth 0.0 {far} to 1.0 {near} diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py index b85f19d29b79defa10493bdbaa4a1b237cb2a9ee..a495b58b7f6481d4cdedf73f23615d0390eb6a45 100644 --- a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py @@ -172,7 +172,7 @@ class AdjustValueInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_np = self._adjust_value_in_yiq_np(x_np, scale) y_tf = self._adjust_value_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -237,7 +237,7 @@ class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale) y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -291,6 +291,9 @@ class AdjustHueInYiqBenchmark(test.Benchmark): def benchmark_adjust_hue_in_yiqCpuAll(self): self._benchmark_adjust_hue_in_yiq('/cpu:0', None) + def benchmark_adjust_hue_in_yiq_gpu_all(self): + self._benchmark_adjust_hue_in_yiq(test.gpu_device_name(), None) + class AdjustSaturationInYiqBenchmark(test.Benchmark): @@ -333,6 +336,9 @@ class AdjustSaturationInYiqBenchmark(test.Benchmark): def benchmark_adjust_saturation_in_yiq_cpu_all(self): self._benchmark_adjust_saturation_in_yiq('/cpu:0', None) + def benchmark_adjust_saturation_in_yiq_gpu_all(self): + self._benchmark_adjust_saturation_in_yiq(test.gpu_device_name(), None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index 011ddeaa9a1eebaa507c9e0d33f9546ff3497166..faedee6f87772016561671bacd87f88657eafffb 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -224,7 +224,8 @@ def transform(images, transforms, interpolation="NEAREST", name=None): `(x, y)` to a transformed *input* point `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to - the transform mapping input points to output points. + the transform mapping input points to output points. Note that gradients + are not backpropagated into transformation parameters. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". Returns: diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index 5cccf26028ca6bf269dbc67a33075351edecb407..bb766e59d2cee648042cc08be466796d9233ad66 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -68,7 +68,7 @@ def single_image_random_dot_stereograms( ``` Args: - depth_values: A `Tensor`. Must be one of the following types: + depth_values: A `Tensor`. Must be one of the following types: `float64`, `float32`, `int64`, `int32`. Z values of data to encode into 'output_data_window' window, lower further away {0.0 floor(far), 1.0 ceiling(near) after norm}, must be 2-D tensor @@ -84,17 +84,17 @@ def single_image_random_dot_stereograms( mu: An optional `float`. Defaults to `0.3333`. Depth of field, Fraction of viewing distance (eg. 1/3 = 0.3333) normalize: An optional `bool`. Defaults to `True`. - Normalize input data to [0.0, 1.0] + Normalize input data to [0.0, 1.0] normalize_max: An optional `float`. Defaults to `-100`. Fix MAX value for Normalization (0.0) - if < MIN, autoscale normalize_min: An optional `float`. Defaults to `100`. Fix MIN value for Normalization (0.0) - if > MAX, autoscale border_level: An optional `float`. Defaults to `0`. - Value of bord in depth 0.0 {far} to 1.0 {near} + Value of bord in depth 0.0 {far} to 1.0 {near} number_colors: An optional `int`. Defaults to `256`. 2 (Black & White), 256 (grayscale), and Numbers > 256 (Full Color) are supported - output_image_shape: An optional `tf.TensorShape` or list of `ints`. + output_image_shape: An optional `tf.TensorShape` or list of `ints`. Defaults to shape `[1024, 768, 1]`. Defines output shape of returned image in '[X,Y, Channels]' 1-grayscale, 3 color; channels will be updated to 3 if number_colors > 256 diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 5d86373a232d55cd281d06cfc0606f4224d8f669..95fba59e3c96ae3c69e0b154740785b0d2bcb3c9 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -16,6 +16,7 @@ py_test( "//tensorflow/contrib/kfac/python/ops:utils", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", @@ -33,6 +34,7 @@ py_test( "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", @@ -68,6 +70,7 @@ py_test( srcs = ["layer_collection_test.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", "//tensorflow/contrib/kfac/python/ops:fisher_factors", "//tensorflow/contrib/kfac/python/ops:layer_collection", "//tensorflow/python:array_ops", @@ -75,6 +78,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:variable_scope", @@ -88,7 +92,6 @@ py_test( deps = [ "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", "//tensorflow/contrib/kfac/python/ops:layer_collection", - "//tensorflow/contrib/kfac/python/ops:loss_functions", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", @@ -139,6 +142,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", + "//tensorflow/python:random_ops", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index b52a7b52a7efd4292ad514c5a744c4da07082142..9b28c45c7263208d21b1514ae5f05b7e81e315a3 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.contrib.kfac.python.ops import estimator from tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -33,6 +34,30 @@ from tensorflow.python.platform import test _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] +class DeviceContextGeneratorTest(test.TestCase): + + def testNoDevice(self): + device_context_generator = estimator._DeviceContextGenerator(None) + with ops.device("/device:CPU:0"): # This is what will be used + with device_context_generator(): # Does nothing + a = constant_op.constant([2.0], name="a") + self.assertEqual("/device:CPU:0", a.op.device) + + def testTwoDevices(self): + device_context_generator = estimator._DeviceContextGenerator( + ["/device:GPU:0", "/device:GPU:1"]) + with ops.device("/device:CPU:0"): # Will be over-ridden by the inner scopes + with device_context_generator(): + a = constant_op.constant([2.0], name="a") + with device_context_generator(): + b = constant_op.constant([2.0], name="b") + with device_context_generator(): + c = constant_op.constant([2.0], name="c") + self.assertEqual("/device:GPU:0", a.op.device) + self.assertEqual("/device:GPU:1", b.op.device) + self.assertEqual("/device:GPU:0", c.op.device) + + class EstimatorTest(test.TestCase): def setUp(self): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index dbf40fccc8257b1dec6cbd790adfa59161ab9049..2d9b28185ce0db32d5cd7d84737fdf96e2c98851 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -40,13 +40,29 @@ def _make_psd(dim): return array_ops.constant(mat) +class UtilsTest(test.TestCase): + + def testComputePiTracenorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + left_factor = array_ops.diag([1., 2., 0., 1.]) + right_factor = array_ops.ones([2., 2.]) + + # pi is the sqrt of the left trace norm divided by the right trace norm + pi = fb._compute_pi_tracenorm(left_factor, right_factor) + + pi_val = sess.run(pi) + self.assertEqual(1., pi_val) + + class FullFBTest(test.TestCase): def testFullFBInitSingleTensor(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params, 32) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -54,7 +70,8 @@ class FullFBTest(test.TestCase): with ops.Graph().as_default(): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params, 32) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -62,7 +79,8 @@ class FullFBTest(test.TestCase): with ops.Graph().as_default(): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params, 32) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -71,7 +89,8 @@ class FullFBTest(test.TestCase): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params, 32) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) @@ -88,7 +107,8 @@ class FullFBTest(test.TestCase): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) - block = fb.FullFB(lc.LayerCollection(), params, 32) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) grads = params**2 block.instantiate_factors((grads,), 0.5) @@ -105,7 +125,8 @@ class FullFBTest(test.TestCase): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.FullFB(lc.LayerCollection(), params, 32) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) damping = 0.5 block.instantiate_factors((grads,), damping) @@ -131,7 +152,8 @@ class NaiveDiagonalFBTest(test.TestCase): with ops.Graph().as_default(): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -139,7 +161,8 @@ class NaiveDiagonalFBTest(test.TestCase): with ops.Graph().as_default(): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) self.assertAllEqual(params, block.tensors_to_compute_grads()) @@ -147,7 +170,8 @@ class NaiveDiagonalFBTest(test.TestCase): with ops.Graph().as_default(): random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors(grads, 0.5) @@ -156,7 +180,8 @@ class NaiveDiagonalFBTest(test.TestCase): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) grads = (params[0]**2, math_ops.sqrt(params[1])) block.instantiate_factors((grads,), 0.5) @@ -173,7 +198,8 @@ class NaiveDiagonalFBTest(test.TestCase): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) params = array_ops.constant([[1.], [2.]]) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) grads = params**2 block.instantiate_factors((grads,), 0.5) @@ -189,7 +215,8 @@ class NaiveDiagonalFBTest(test.TestCase): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) - block = fb.NaiveDiagonalFB(lc.LayerCollection(), params, 32) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) grads = (params[0]**2, math_ops.sqrt(params[1])) damping = 0.5 block.instantiate_factors((grads,), damping) @@ -289,8 +316,7 @@ class FullyConnectedDiagonalFB(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -572,8 +598,7 @@ class ConvDiagonalFBTest(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -596,8 +621,9 @@ class ConvDiagonalFBTest(test.TestCase): self.kernel_size, self.kernel_size, self.input_channels + 1, self.output_channels ]) - expected_result = (expected_result[:, :, 0:-1, :], np.reshape( - expected_result[:, :, -1, :], [self.output_channels])) + expected_result = (expected_result[:, :, 0:-1, :], + np.reshape(expected_result[:, :, -1, :], + [self.output_channels])) self.assertEqual(len(result), 2) self.assertAllClose(expected_result[0], result[0]) @@ -680,8 +706,8 @@ class ConvKFCBasicFBTest(test.TestCase): sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) - vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), np.arange( - 2, 4).reshape(2, 1).astype(np.float32)) + vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), + np.arange(2, 4).reshape(2, 1).astype(np.float32)) output = block.multiply_inverse((array_ops.constant(vector[0]), array_ops.constant(vector[1]))) @@ -764,11 +790,50 @@ class ConvKFCBasicFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) +class FullyConnectedSeriesFBTest(test.TestCase): + + def testFullyConnectedSeriesFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), inputs=[inputs], outputs=[outputs]) + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=True) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=False) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def as_tensors(tensor_or_tuple): """Converts a potentially nested tuple of np.array to Tensors.""" if isinstance(tensor_or_tuple, (tuple, list)): return tuple(as_tensors(t) for t in tensor_or_tuple) return ops.convert_to_tensor(tensor_or_tuple) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index fbb3d219139a4bc05253841a89e73645ef37dddd..70e56db055078bd4399b03e4d4a877e34249cc5e 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -22,6 +22,7 @@ import numpy as np import numpy.random as npr from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops from tensorflow.python.framework import random_seed @@ -32,6 +33,25 @@ from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import test +class MaybeColocateTest(test.TestCase): + + def testFalse(self): + with tf_ops.Graph().as_default(): + a = constant_op.constant([2.0], name='a') + with ff._maybe_colocate_with(a, False): + b = constant_op.constant(3.0, name='b') + self.assertEqual([b'loc:@a'], a.op.colocation_groups()) + self.assertEqual([b'loc:@b'], b.op.colocation_groups()) + + def testTrue(self): + with tf_ops.Graph().as_default(): + a = constant_op.constant([2.0], name='a') + with ff._maybe_colocate_with(a, True): + b = constant_op.constant(3.0, name='b') + self.assertEqual([b'loc:@a'], a.op.colocation_groups()) + self.assertEqual([b'loc:@a'], b.op.colocation_groups()) + + class FisherFactorTestingDummy(ff.FisherFactor): """Dummy class to test the non-abstract methods on ff.FisherFactor.""" @@ -47,12 +67,19 @@ class FisherFactorTestingDummy(ff.FisherFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError def instantiate_covariance(self): pass + def make_inverse_update_ops(self): + return [] + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -74,6 +101,10 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError @@ -101,7 +132,7 @@ class NumericalUtilsTest(test.TestCase): normalizer = 10. x = npr.randn(100, 3) - cov = ff._compute_cov(array_ops.constant(x), normalizer) + cov = ff._compute_cov(array_ops.constant(x), normalizer=normalizer) np_cov = np.dot(x.T, x) / normalizer self.assertAllClose(sess.run(cov), np_cov) @@ -247,13 +278,13 @@ class InverseProvidingFactorTest(test.TestCase): for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): factor.register_damped_inverse(1. / i) ops = factor.make_inverse_update_ops() - self.assertEqual(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD, len(ops)) + self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) new_invs = [] + sess.run(ops) for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): # The inverse op will assign the damped inverse of cov to the inv var. - sess.run(ops[i - 1]) new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) # We want to see that the new invs are all different from each other. for i in range(len(new_invs)): @@ -311,6 +342,16 @@ class FullFactorTest(test.TestCase): factor = ff.FullFactor((tensor,), 32) self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) + def testFullFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 6], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -331,6 +372,16 @@ class NaiveDiagonalFactorTest(test.TestCase): factor = ff.NaiveDiagonalFactor((tensor,), 32) self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + def testNaiveDiagonalFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 1], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -344,18 +395,25 @@ class NaiveDiagonalFactorTest(test.TestCase): class FullyConnectedKroneckerFactorTest(test.TestCase): - def _testFullyConnectedKroneckerFactorInit(self, has_bias, final_shape): + def _testFullyConnectedKroneckerFactorInit(self, + has_bias, + final_shape, + dtype=dtypes.float32_ref): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) - self.assertEqual(final_shape, factor.get_cov().get_shape().as_list()) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual(final_shape, cov.get_shape().as_list()) def testFullyConnectedKroneckerFactorInitNoBias(self): - self._testFullyConnectedKroneckerFactorInit(False, [3, 3]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) def testFullyConnectedKroneckerFactorInitWithBias(self): - self._testFullyConnectedKroneckerFactorInit(True, [4, 4]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: @@ -398,6 +456,18 @@ class ConvInputKroneckerFactorTest(test.TestCase): self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) + def testConvInputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + cov.get_shape().as_list()) + def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -433,6 +503,16 @@ class ConvOutputKroneckerFactorTest(test.TestCase): factor = ff.ConvOutputKroneckerFactor((tensor,)) self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) + def testConvOutputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') + factor = ff.ConvOutputKroneckerFactor((tensor,)) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([5, 5], cov.get_shape().as_list()) + def testConvOutputKroneckerFactorInitNotEnoughDims(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -451,5 +531,49 @@ class ConvOutputKroneckerFactorTest(test.TestCase): self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) +class FullyConnectedMultiKFTest(test.TestCase): + + def testFullyConnectedMultiKFInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) + + def testFullyConnectedMultiKFInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([3, 3], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,)) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py index db7ab63c7d1166649acbe41851a5876d8af476db..b8ccbeadd0a9d69edb41fef50e3edb090457adf2 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.kfac.python.ops import fisher_blocks from tensorflow.contrib.kfac.python.ops import fisher_factors from tensorflow.contrib.kfac.python.ops import layer_collection from tensorflow.python.framework import dtypes @@ -25,11 +26,27 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test +class MockFisherBlock(object): + """A fake FisherBlock.""" + + num_registered_minibatches = 2 + + def __init__(self, name='MockFisherBlock'): + self.name = name + + def __eq__(self, other): + return isinstance(other, MockFisherBlock) and other.name == self.name + + def __hash__(self): + return hash(self.name) + + class LayerParametersDictTest(test.TestCase): def testSetItem(self): @@ -90,8 +107,10 @@ class LayerCollectionTest(test.TestCase): array_ops.constant(4), [1, 1, 1, 1], 'SAME', array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) lc.register_conv2d( - array_ops.constant(4), [1, 1, 1, 1], 'SAME', - array_ops.ones((1, 1, 1, 1)), array_ops.constant(3), + array_ops.constant(4), [1, 1, 1, 1], + 'SAME', + array_ops.ones((1, 1, 1, 1)), + array_ops.constant(3), approx=layer_collection.APPROX_DIAGONAL_NAME) lc.register_generic( array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) @@ -107,10 +126,11 @@ class LayerCollectionTest(test.TestCase): random_seed.set_random_seed(200) lc = layer_collection.LayerCollection() key = array_ops.constant(1) - lc.register_fully_connected(key, - array_ops.constant(2), array_ops.constant(3)) - with self.assertRaises(ValueError): + lc.register_fully_connected(key, array_ops.constant(2), + array_ops.constant(3)) + with self.assertRaises(ValueError) as cm: lc.register_generic(key, 16) + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterSingleParamNotRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -125,16 +145,18 @@ class LayerCollectionTest(test.TestCase): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {x: '1'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block(x, 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterSingleParamRegisteredInTuple(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y): '1'} - lc.register_block(x, 'foo') - self.assertEqual(set(['1']), set(lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block(x, 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleParamNotRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -154,8 +176,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y): '1'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block((x, y), 'foo') + self.assertIn('already in LayerCollection', str(cm.exception)) def testRegisterTupleParamRegisteredInSuperset(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -164,18 +187,20 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, y, z): '1'} - lc.register_block((x, y), 'foo') - self.assertEqual(set(['1']), set(lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleParamSomeRegistered(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) lc = layer_collection.LayerCollection() - lc.fisher_blocks = {x: '1', z: '2'} + lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} - lc.register_block((x, y), 'foo') - self.assertEqual(set(['2', 'foo']), set(lc.get_blocks())) + with self.assertRaises(ValueError) as cm: + lc.register_block((x, y), MockFisherBlock('foo')) + self.assertIn('was already registered', str(cm.exception)) def testRegisterTupleVarSomeRegisteredInOtherTuples(self): x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) @@ -185,8 +210,9 @@ class LayerCollectionTest(test.TestCase): lc = layer_collection.LayerCollection() lc.fisher_blocks = {(x, z): '1', (z, w): '2'} - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as cm: lc.register_block((x, y), 'foo') + self.assertIn('was already registered', str(cm.exception)) def testRegisterCategoricalPredictiveDistribution(self): with ops.Graph().as_default(), self.test_session() as sess: @@ -406,6 +432,23 @@ class LayerCollectionTest(test.TestCase): self.ensureLayerReuseWorks(register_fn) + def testReuseWithInvalidRegistration(self): + """Invalid registrations shouldn't overwrite existing blocks.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 5, 5, 10]) + outputs = array_ops.zeros([2, 5, 5, 3]) + w = variable_scope.get_variable('w', [1, 1, 10, 3]) + b = variable_scope.get_variable('b', [3]) + lc = layer_collection.LayerCollection() + lc.register_fully_connected(w, inputs, outputs) + self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + with self.assertRaises(KeyError): + lc.register_fully_connected((w, b), inputs, outputs, reuse=True) + self.assertNotIn((w, b), lc.fisher_blocks) + self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 1) + lc.register_fully_connected(w, inputs, outputs, reuse=True) + self.assertEqual(lc.fisher_blocks[w].num_registered_minibatches, 2) + def testMakeOrGetFactor(self): with ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -438,11 +481,6 @@ class LayerCollectionTest(test.TestCase): def testGetUseCountMap(self): """Ensure get_use_count_map() sums 'num_registered_minibatches'.""" - - class MockFisherBlock(object): - - num_registered_minibatches = 2 - lc = layer_collection.LayerCollection() lc.fisher_blocks = { 'a': MockFisherBlock(), @@ -452,6 +490,66 @@ class LayerCollectionTest(test.TestCase): use_count_map = lc.get_use_count_map() self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) + def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + z = variable_scope.get_variable('z', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters((x, z)) + + def testIdentifySubsetPreviouslyRegisteredTensor(self): + x = variable_scope.get_variable('x', shape=()) + y = variable_scope.get_variable('y', shape=()) + lc = layer_collection.LayerCollection() + lc.define_linked_parameters((x, y)) + + with self.assertRaises(ValueError): + lc.define_linked_parameters(x) + + def testSpecifyApproximation(self): + w_0 = variable_scope.get_variable('w_0', [10, 10]) + w_1 = variable_scope.get_variable('w_1', [10, 10]) + + b_0 = variable_scope.get_variable('b_0', [10]) + b_1 = variable_scope.get_variable('b_1', [10]) + + x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10)) + + pre_bias_0 = math_ops.matmul(x_0, w_0) + pre_bias_1 = math_ops.matmul(x_1, w_1) + + # Build the fully connected layers in the graph. + pre_bias_0 + b_0 # pylint: disable=pointless-statement + pre_bias_1 + b_1 # pylint: disable=pointless-statement + + lc = layer_collection.LayerCollection() + lc.define_linked_parameters( + w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME) + lc.define_linked_parameters( + b_0, approximation=layer_collection.APPROX_FULL_NAME) + lc.define_linked_parameters( + b_1, approximation=layer_collection.APPROX_FULL_NAME) + + lc.register_fully_connected(w_0, x_0, pre_bias_0) + lc.register_fully_connected( + w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME) + self.assertIsInstance(lc.fisher_blocks[w_0], + fisher_blocks.FullyConnectedDiagonalFB) + self.assertIsInstance(lc.fisher_blocks[w_1], + fisher_blocks.FullyConnectedKFACBasicFB) + + lc.register_generic(b_0, batch_size=1) + lc.register_generic( + b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME) + self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB) + self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index 87339cb059802ec8944d5d1ae4557ee34550cd60..39ce3e9337157c8206107bc40c489e44019743ab 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -24,6 +24,7 @@ from tensorflow.contrib.kfac.python.ops import loss_functions from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -96,6 +97,22 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): # difficult to say if the output is correct or not... neg_log_prob = sess.run(neg_log_prob) + def testMultiMinibatchRegistration(self): + """Ensure this loss function supports registering multiple minibatches.""" + with ops.Graph().as_default(): + tower_logits = [] + loss = None + num_towers = 5 + for _ in range(num_towers): + logits = random_ops.random_uniform(shape=[2, 3]) + tower_logits.append(logits) + if loss is None: + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + else: + loss.register_additional_minibatch(logits) + self.assertListEqual(loss.input_minibatches, tower_logits) + self.assertEqual(loss.num_registered_minibatches, num_towers) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 55fe38e3e9aab2dbd70a45cdc8fa0c208b036db0..d255a6e7160386d8eb6fca00765eea8a318f4eaa 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -222,18 +222,6 @@ class UtilsTest(test.TestCase): self.assertAllClose(b, np.array([4., 5.])) self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) - def testComputePi(self): - with ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) - - # pi is the sqrt of the left trace norm divided by the right trace norm - pi = utils.compute_pi(left_factor, right_factor) - - pi_val = sess.run(pi) - self.assertEqual(1., pi_val) - def testPosDefInvCholesky(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index de4b8920b849dbf2117657de6e7c26f94f4d0363..3d731c7bc206d6f168e9b8f29b66bf4f1dbe8542 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -38,6 +38,7 @@ py_library( ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:special_math_ops", @@ -171,6 +172,7 @@ py_library( deps = [ ":utils", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:util", "//third_party/py/numpy", diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index 6e2c9ecdce7ad9f98a5beb016770ad2b1e197b0a..5e1680967c184bf19f2a2578219db07a48264dc9 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -18,16 +18,53 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import math +import contextlib +import itertools import numpy as np from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.util import nest +class _DeviceContextGenerator(object): + """Class for generating device contexts in a round-robin fashion.""" + + def __init__(self, devices): + """Creates a _DeviceContextGenerator object. + + Example usage: + + ```python + dcg = _DeviceContextGenerator(['/gpu:0', 'gpu:1']) + with dcg(): + # All operations in this context will be placed on GPU 0 + ... + with dcg(): + # All operations in this context will be placed on GPU 1 + ... + ``` + + Args: + devices: An iterable of device strings (or None). Successive calls to + __call__ will give contexts which place devices on these devices in + a round-robin fashion. + """ + self._cycle = None if devices is None else itertools.cycle(devices) + + @contextlib.contextmanager + def __call__(self): + """Returns a context manager specifying the default device.""" + if self._cycle is None: + yield + else: + with tf_ops.device(next(self._cycle)): + yield + + class FisherEstimator(object): """Fisher estimator class supporting various approximations of the Fisher.""" @@ -36,7 +73,10 @@ class FisherEstimator(object): cov_ema_decay, damping, layer_collection, - estimation_mode="gradients"): + estimation_mode="gradients", + colocate_gradients_with_ops=False, + cov_devices=None, + inv_devices=None): """Create a FisherEstimator object. Args: @@ -54,7 +94,7 @@ class FisherEstimator(object): blocks, kronecker factors, and losses associated with the graph. estimation_mode: The type of estimator to use for the Fishers. Can be - 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + 'gradients', 'empirical', 'curvature_prop', or 'exact'. (Default: 'gradients'). 'gradients' is the basic estimation approach from the original K-FAC paper. 'empirical' computes the 'empirical' Fisher information matrix (which uses the data's distribution for the @@ -69,6 +109,14 @@ class FisherEstimator(object): for each coordinate of the output instead of using 1/-1 vectors. It is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking. + colocate_gradients_with_ops: Whether we should request gradients be + colocated with their respective ops. + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. Raises: ValueError: If no losses have been registered with layer_collection. @@ -79,13 +127,19 @@ class FisherEstimator(object): self._estimation_mode = estimation_mode self._layers = layer_collection self._layers.create_subgraph() - self._check_registration(variables) + self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, "curvature_prop": self._get_grads_lists_curvature_prop, "exact": self._get_grads_lists_exact } + self._colocate_gradients_with_ops = colocate_gradients_with_ops + self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) + if inv_devices == cov_devices: + self._inv_device_context_generator = self._cov_device_context_generator + else: + self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) setup = self._setup(cov_ema_decay) self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup @@ -148,49 +202,6 @@ class FisherEstimator(object): return self._apply_transformation(vecs_and_vars, lambda fb, vec: fb.multiply(vec)) - def _check_registration(self, variables): - """Checks that all variable uses have been registered properly. - - Args: - variables: List of variables. - - Raises: - ValueError: If any registered variables are not included in the list. - ValueError: If any variable in the list is not registered. - ValueError: If any variable in the list is registered with the wrong - number of "uses" in the subgraph recorded (vs the number of times that - variable is actually used in the subgraph). - """ - # Note that overlapping parameters (i.e. those that share variables) will - # be caught by layer_collection.LayerParametersDict during registration. - - reg_use_map = self._layers.get_use_count_map() - - error_messages = [] - - for var in variables: - total_uses = self._layers.subgraph.variable_uses(var) - reg_uses = reg_use_map[var] - - if reg_uses == 0: - error_messages.append("Variable {} not registered.".format(var)) - elif (not math.isinf(reg_uses)) and reg_uses != total_uses: - error_messages.append( - "Variable {} registered with wrong number of uses ({} " - "vs {} actual).".format(var, reg_uses, total_uses)) - - num_get_vars = len(reg_use_map) - - if num_get_vars > len(variables): - error_messages.append("{} registered variables were not included in list." - .format(num_get_vars - len(variables))) - - if error_messages: - error_messages = [ - "Found the following errors with variable registration:" - ] + error_messages - raise ValueError("\n\t".join(error_messages)) - def _setup(self, cov_ema_decay): """Sets up the various operations. @@ -219,8 +230,13 @@ class FisherEstimator(object): raise ValueError("Unrecognized value {} for estimation_mode.".format( self._estimation_mode)) + # TODO(b/68033310): This loop round-robins the "concat" operations which + # gather the inputs for the cov_updates. In future, we might do these + # computations locally then communicate the results, which would require a + # modification to this code. for grads_list, fb in zip(grads_lists, fisher_blocks_list): - fb.instantiate_factors(grads_list, self.damping) + with self._cov_device_context_generator(): + fb.instantiate_factors(grads_list, self.damping) cov_updates = [ factor.make_covariance_update_op(cov_ema_decay) @@ -233,18 +249,23 @@ class FisherEstimator(object): def _get_all_inverse_update_ops(self): for factor in self._layers.get_factors(): - for op in factor.make_inverse_update_ops(): - yield op + with self._inv_device_context_generator(): + for op in factor.make_inverse_update_ops(): + yield op def _get_grads_lists_gradients(self, tensors): - grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(), - nest.flatten(tensors)) + grads_flat = gradients_impl.gradients( + self._layers.total_sampled_loss(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_empirical(self, tensors): - grads_flat = gradients_impl.gradients(self._layers.total_loss(), - nest.flatten(tensors)) + grads_flat = gradients_impl.gradients( + self._layers.total_loss(), + nest.flatten(tensors), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) @@ -262,11 +283,13 @@ class FisherEstimator(object): grads_flat = gradients_impl.gradients( nest.flatten(loss_inputs), nest.flatten(tensors), - grad_ys=nest.flatten(transformed_random_signs)) + grad_ys=nest.flatten(transformed_random_signs), + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all = nest.pack_sequence_as(tensors, grads_flat) return tuple((grad,) for grad in grads_all) def _get_grads_lists_exact(self, tensors): + """No docstring required.""" # Loop over all coordinates of all losses. grads_all = [] for loss in self._layers.losses: @@ -274,6 +297,9 @@ class FisherEstimator(object): transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( index) grads_flat = gradients_impl.gradients( - loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot) + loss.inputs, + nest.flatten(tensors), + grad_ys=transformed_one_hot, + colocate_gradients_with_ops=self._colocate_gradients_with_ops) grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) return zip(*grads_all) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index efffaaef8d56aed3a1cdbf2df1d8209d58b3502f..1ccb9e040f2bb6bcfd217886918abd40e3cc1cfb 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -38,6 +38,7 @@ from __future__ import division from __future__ import print_function import abc +import enum # pylint: disable=g-bad-import-order import six @@ -52,14 +53,54 @@ from tensorflow.python.ops import math_ops # damping /= num_replications ** NORMALIZE_DAMPING_POWER NORMALIZE_DAMPING_POWER = 1.0 +# Methods for adjusting damping for FisherBlocks. See +# _compute_pi_adjusted_damping() for details. +PI_OFF_NAME = "off" +PI_TRACENORM_NAME = "tracenorm" +PI_TYPE = PI_TRACENORM_NAME -def set_global_constants(normalize_damping_power=None): + +def set_global_constants(normalize_damping_power=None, pi_type=None): """Sets various global constants used by the classes in this module.""" global NORMALIZE_DAMPING_POWER + global PI_TYPE if normalize_damping_power is not None: NORMALIZE_DAMPING_POWER = normalize_damping_power + if pi_type is not None: + PI_TYPE = pi_type + + +def _compute_pi_tracenorm(left_cov, right_cov): + """Computes the scalar constant pi for Tikhonov regularization/damping. + + pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. + + Args: + left_cov: The left Kronecker factor "covariance". + right_cov: The right Kronecker factor "covariance". + + Returns: + The computed scalar constant pi for these Kronecker Factors (as a Tensor). + """ + # Instead of dividing by the dim of the norm, we multiply by the dim of the + # other norm. This works out the same in the ratio. + left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] + right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] + return math_ops.sqrt(left_norm / right_norm) + + +def _compute_pi_adjusted_damping(left_cov, right_cov, damping): + + if PI_TYPE == PI_TRACENORM_NAME: + pi = _compute_pi_tracenorm(left_cov, right_cov) + return (damping * pi, damping / pi) + + elif PI_TYPE == PI_OFF_NAME: + return (damping, damping) + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): @@ -133,16 +174,15 @@ class FullFB(FisherBlock): to any type of parameter in principle, but has very high variance. """ - def __init__(self, layer_collection, params, batch_size): + def __init__(self, layer_collection, params): """Creates a FullFB block. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. params: The parameters of this layer (Tensor or tuple of Tensors). - batch_size: The batch size, used in the covariance estimator. """ - self._batch_size = batch_size + self._batch_sizes = [] self._params = params super(FullFB, self).__init__(layer_collection) @@ -154,7 +194,7 @@ class FullFB(FisherBlock): self._factor.register_damped_inverse(damping) def multiply_inverse(self, vector): - inverse = self._factor.get_inverse(self._damping) + inverse = self._factor.get_damped_inverse(self._damping) out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) return utils.column_to_tensors(vector, out_flat) @@ -172,9 +212,21 @@ class FullFB(FisherBlock): def tensors_to_compute_grads(self): return self._params + def register_additional_minibatch(self, batch_size): + """Register an additional minibatch. + + Args: + batch_size: The batch size, used in the covariance estimator. + """ + self._batch_sizes.append(batch_size) + @property def num_registered_minibatches(self): - return 1 # Multiple minibatches not supported. + return len(self._batch_sizes) + + @property + def _batch_size(self): + return math_ops.reduce_sum(self._batch_sizes) class NaiveDiagonalFB(FisherBlock): @@ -186,17 +238,16 @@ class NaiveDiagonalFB(FisherBlock): to any type of parameter in principle, but has very high variance. """ - def __init__(self, layer_collection, params, batch_size): + def __init__(self, layer_collection, params): """Creates a NaiveDiagonalFB block. Args: layer_collection: The collection of all layers in the K-FAC approximate Fisher information matrix to which this FisherBlock belongs. params: The parameters of this layer (Tensor or tuple of Tensors). - batch_size: The batch size, used in the covariance estimator. """ self._params = params - self._batch_size = batch_size + self._batch_sizes = [] super(NaiveDiagonalFB, self).__init__(layer_collection) @@ -221,9 +272,21 @@ class NaiveDiagonalFB(FisherBlock): def tensors_to_compute_grads(self): return self._params + def register_additional_minibatch(self, batch_size): + """Register an additional minibatch. + + Args: + batch_size: The batch size, used in the covariance estimator. + """ + self._batch_sizes.append(batch_size) + @property def num_registered_minibatches(self): - return 1 # Multiple minibatches not supported. + return len(self._batch_sizes) + + @property + def _batch_size(self): + return math_ops.reduce_sum(self._batch_sizes) class FullyConnectedDiagonalFB(FisherBlock): @@ -389,7 +452,7 @@ class ConvDiagonalFB(FisherBlock): (self._strides[1] * self._strides[2])) if NORMALIZE_DAMPING_POWER: - damping /= self._num_locations ** NORMALIZE_DAMPING_POWER + damping /= self._num_locations**NORMALIZE_DAMPING_POWER self._damping = damping self._factor = self._layer_collection.make_or_get_factor( @@ -443,11 +506,10 @@ class KroneckerProductFB(FisherBlock): Args: damping: The base damping factor (float or Tensor) for the damped inverse. """ - pi = utils.compute_pi(self._input_factor.get_cov(), - self._output_factor.get_cov()) - - self._input_damping = math_ops.sqrt(damping) * pi - self._output_damping = math_ops.sqrt(damping) / pi + self._input_damping, self._output_damping = _compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) self._input_factor.register_damped_inverse(self._input_damping) self._output_factor.register_damped_inverse(self._output_damping) @@ -465,8 +527,9 @@ class KroneckerProductFB(FisherBlock): return 1.0 def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_inverse(self._output_damping) + left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) + right_factor_inv = self._output_factor.get_damped_inverse( + self._output_damping) reshaped_vector = utils.layer_params_to_mat2d(vector) reshaped_out = math_ops.matmul(left_factor_inv, math_ops.matmul(reshaped_vector, @@ -698,3 +761,260 @@ def _concat_along_batch_dim(tensor_list): def _num_conv_locations(input_shape, strides): """Returns the number of locations a Conv kernel is applied to.""" return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + + +class FullyConnectedMultiIndepFB(KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters. + """ + + def __init__(self, layer_collection, inputs, outputs, has_bias=False): + """Creates a FullyConnectedMultiIndepFB block. + + Args: + layer_collection: LayerCollection instance. + inputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + inputs_size]. + outputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + outputs_size]. + has_bias: bool. If True, estimates Fisher with respect to a bias + parameter as well as the layer's parameters. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_uses = len(inputs) + + super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. + return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, + ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + if NORMALIZE_DAMPING_POWER: + damping /= self._num_uses**NORMALIZE_DAMPING_POWER + + self._register_damped_input_and_output_inverses(damping) + + @property + def _renorm_coeff(self): + return self._num_uses + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) + + +class SeriesFBApproximation(enum.IntEnum): + """See FullyConnectedSeriesFB.__init__ for description and usage.""" + option1 = 1 + option2 = 2 + + +class FullyConnectedSeriesFB(FisherBlock): + """FisherBlock for fully-connected layers that share parameters across time. + + See the following preprint for details: + https://openreview.net/pdf?id=HyMTkQZAb + + See the end of the appendix of the paper for a pseudo-code of the + algorithm being implemented by multiply_inverse here. Note that we are + using pre-computed versions of certain matrix-matrix products to speed + things up. This is explicitly explained wherever it is done. + """ + + def __init__(self, + layer_collection, + inputs, + outputs, + has_bias=False, + option=SeriesFBApproximation.option2): + """Constructs a new `FullyConnectedSeriesFB`. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + inputs: List of tensors of shape [batch_size, input_size]. + Inputs to the layer. + outputs: List of tensors of shape [batch_size, input_size]. + Outputs of the layer (before activations). + has_bias: Whether the layer includes a bias parameter. + option: A `SeriesFBApproximation` specifying the simplifying assumption + to be used in this block. `option1` approximates the cross-covariance + over time as a symmetric matrix, while `option2` makes + the assumption that training sequences are infinitely long. See section + 3.5 of the paper for more details. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_timesteps = len(inputs) + self._option = option + + super(FullyConnectedSeriesFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. + return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + if NORMALIZE_DAMPING_POWER: + damping /= self._num_timesteps**NORMALIZE_DAMPING_POWER + + self._damping_input, self._damping_output = _compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) + + if self._option == SeriesFBApproximation.option1: + self._input_factor.register_option1quants(self._damping_input) + self._output_factor.register_option1quants(self._damping_output) + elif self._option == SeriesFBApproximation.option2: + self._input_factor.register_option2quants(self._damping_input) + self._output_factor.register_option2quants(self._damping_output) + else: + raise ValueError( + "Unrecognized FullyConnectedSeriesFB approximation: {}".format( + self._option)) + + def multiply_inverse(self, vector): + # pylint: disable=invalid-name + + Z = utils.layer_params_to_mat2d(vector) + + # Derivations were done for "batch_dim==1" case so we need to convert to + # that orientation: + Z = array_ops.transpose(Z) + + if self._option == SeriesFBApproximation.option1: + + # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. + L_A, psi_A = self._input_factor.get_option1quants(self._damping_input) + L_G, psi_G = self._output_factor.get_option1quants(self._damping_output) + + def gamma(x): + # We are assuming that each case has the same number of time-steps. + # If this stops being the case one shouldn't simply replace this T + # with its average value. Instead, one needs to go back to the + # definition of the gamma function from the paper. + T = self._num_timesteps + return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) + + # Y = gamma( psi_G*psi_A^T ) (computed element-wise) + # Even though Y is Z-independent we are recomputing it from the psi's + # each since Y depends on both A and G quantities, and it is relatively + # cheap to compute. + Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) + + # Z = L_G^T * Z * L_A + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = U_G^T * Z * U_A + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) + + # Z = Z .* Y + Z *= Y + + # Z = L_G * Z * L_A^T + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = U_G * Z * U_A^T + # Z = G0^(-1/2) * Z * A0^(-1/2) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) + + elif self._option == SeriesFBApproximation.option2: + + # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), + # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. + P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input) + P_G, K_G, mu_G = self._output_factor.get_option2quants( + self._damping_output) + + # Our approach differs superficially from the pseudo-code in the paper + # in order to reduce the total number of matrix-matrix multiplies. + # In particular, the first three computations in the pseudo code are + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = Z - hPsi_G^T * Z * hPsi_A + # Z = E_G^T * Z * E_A + # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that + # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2) + # the entire computation can be written as + # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A + # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A + # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A + # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A + # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A + # This final expression is computed by the following two lines: + # Z = Z - P_G * Z * P_A^T + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) + # Z = K_G^T * Z * K_A + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) + + # Z = Z ./ (1*1^T - mu_G*mu_A^T) + # Be careful with the outer product. We don't want to accidentally + # make it an inner-product instead. + tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A + # Prevent some numerical issues by setting any 0.0 eigs to 1.0 + tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) + Z /= tmp + + # We now perform the transpose/reverse version of the operations + # derived above, whose derivation from the original pseudo-code is + # analgous. + # Z = K_G * Z * K_A^T + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) + + # Z = Z - P_G^T * Z * P_A + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) + + # Z = normalize (1/E[T]) * Z + # Note that this normalization is done because we compute the statistics + # by averaging, not summing, over time. (And the gradient is presumably + # summed over time, not averaged, and thus their scales are different.) + Z /= math_ops.cast(self._num_timesteps, Z.dtype) + + # Convert back to the "batch_dim==0" orientation. + Z = array_ops.transpose(Z) + + return utils.mat2d_to_layer_params(vector, Z) + + # pylint: enable=invalid-name + + def multiply(self, vector): + raise NotImplementedError + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 4e36813369e69de1d6f13ddb00566bda912244f6..5a6d1a93ff217c3922f45a047b4d548086ac5258 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import abc +import contextlib import numpy as np import six @@ -26,6 +27,8 @@ import six from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import special_math_ops @@ -50,7 +53,22 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 EIGENVALUE_CLIPPING_THRESHOLD = 0.0 -def set_global_constants(init_covariances_at_zero=None, zero_debias=None, +@contextlib.contextmanager +def _maybe_colocate_with(op, colocate_cov_ops_with_inputs): + """Context to colocate with `op` if `colocate_cov_ops_with_inputs`.""" + if colocate_cov_ops_with_inputs: + if isinstance(op, (list, tuple)): + with tf_ops.colocate_with(op[0]): + yield + else: + with tf_ops.colocate_with(op): + yield + else: + yield + + +def set_global_constants(init_covariances_at_zero=None, + zero_debias=None, eigenvalue_decomposition_threshold=None, eigenvalue_clipping_threshold=None): """Sets various global constants used by the classes in this module.""" @@ -85,7 +103,7 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) -def _compute_cov(tensor, normalizer=None): +def _compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. This function is meant to be applied to random matrices for which the true row @@ -93,6 +111,8 @@ def _compute_cov(tensor, normalizer=None): Args: tensor: A 2D Tensor. + tensor_right: An optional 2D Tensor. If provided, this function computes + the matrix product tensor^T * tensor_right instead of tensor^T * tensor. normalizer: optional scalar for the estimator (by default, the normalizer is the number of rows of tensor). @@ -101,9 +121,14 @@ def _compute_cov(tensor, normalizer=None): """ if normalizer is None: normalizer = array_ops.shape(tensor)[0] - cov = (math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( - normalizer, tensor.dtype)) - return (cov + array_ops.transpose(cov)) / math_ops.cast(2, cov.dtype) + if tensor_right is None: + cov = ( + math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( + normalizer, tensor.dtype)) + return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) + else: + return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / + math_ops.cast(normalizer, tensor.dtype)) def _append_homog(tensor): @@ -119,7 +144,7 @@ def _append_homog(tensor): rank = len(tensor.shape.as_list()) shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) ones = array_ops.ones(shape, dtype=tensor.dtype) - return array_ops.concat([tensor, ones], axis=rank-1) + return array_ops.concat([tensor, ones], axis=rank - 1) def scope_string_from_params(params): @@ -157,8 +182,8 @@ def scope_string_from_params(params): elif isinstance(param, (tf_ops.Tensor, variables.Variable)): name_parts.append(scope_string_from_name(param)) else: - raise ValueError( - "Encountered an unsupported param type {}".format(type(param))) + raise ValueError("Encountered an unsupported param type {}".format( + type(param))) return "_".join(name_parts) @@ -209,6 +234,10 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _dtype(self): + pass + @property def _cov_initializer(self): return covariance_initializer @@ -220,7 +249,8 @@ class FisherFactor(object): "cov", initializer=self._cov_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) @abc.abstractmethod def _compute_new_cov(self, idx=0): @@ -240,9 +270,10 @@ class FisherFactor(object): return moving_averages.assign_moving_average( self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + @abc.abstractmethod def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - return [] + pass def get_cov(self): return self._cov @@ -257,6 +288,13 @@ class InverseProvidingFactor(FisherFactor): _cov_shape properties. """ + # TODO(b/69108481): This class (and its subclasses) should be refactored to + # serve the matrix quantities it computes as both (potentially stale) + # variables, updated by the inverse update ops, and fresh values stored in + # tensors that recomputed once every session.run() call. Currently matpower + # and damp_inverse have the former behavior, while eigendecomposition has + # the latter. + def __init__(self): self._inverses_by_damping = {} self._matpower_by_exp_and_damping = {} @@ -267,6 +305,10 @@ class InverseProvidingFactor(FisherFactor): def register_damped_inverse(self, damping): """Registers a damped inverse needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_inverse. + Args: damping: The damping value (float or Tensor) for this factor. """ @@ -277,12 +319,17 @@ class InverseProvidingFactor(FisherFactor): "inv_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._inverses_by_damping[damping] = inv def register_matpower(self, exp, damping): """Registers a matrix power needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_matpower. + Args: exp: The exponent (float or Tensor) to raise the matrix to. damping: The damping value (float or Tensor). @@ -295,57 +342,78 @@ class InverseProvidingFactor(FisherFactor): "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._matpower_by_exp_and_damping[(exp, damping)] = matpower def register_eigendecomp(self): - """Registers that an eigendecomposition is needed by a FisherBlock.""" + """Registers an eigendecomposition. + + Unlike register_damp_inverse and register_matpower this doesn't create + any variables or inverse ops. Instead it merely makes tensors containing + the eigendecomposition available to anyone that wants them. They will be + recomputed (once) for each session.run() call (when they needed by some op). + """ if not self._eigendecomp: - self._eigendecomp = linalg_ops.self_adjoint_eig(self._cov) + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least FLAGS.eigenvalue_clipping_threshold + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - ops = super(InverseProvidingFactor, self).make_inverse_update_ops() + ops = [] num_inverses = len(self._inverses_by_damping) matrix_power_registered = bool(self._matpower_by_exp_and_damping) - use_eig = (self._eigendecomp or matrix_power_registered or - num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + use_eig = ( + self._eigendecomp or matrix_power_registered or + num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) if use_eig: self.register_eigendecomp() # ensures self._eigendecomp is set eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence - # The matrix self._cov is positive semidefinite by construction, but the - # numerical eigenvalues could be negative due to numerical errors, so here - # we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD. - clipped_eigenvalues = math_ops.maximum(eigenvalues, - EIGENVALUE_CLIPPING_THRESHOLD) - for damping, inv in self._inverses_by_damping.items(): ops.append( inv.assign( - math_ops.matmul(eigenvectors / (clipped_eigenvalues + damping), + math_ops.matmul(eigenvectors / (eigenvalues + damping), array_ops.transpose(eigenvectors)))) for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): ops.append( matpower.assign( - math_ops.matmul(eigenvectors * (clipped_eigenvalues + damping)** - exp, array_ops.transpose(eigenvectors)))) + math_ops.matmul(eigenvectors * + (eigenvalues + damping)**exp, + array_ops.transpose(eigenvectors)))) + # These ops share computation and should be run on a single device. + ops = [control_flow_ops.group(*ops)] else: for damping, inv in self._inverses_by_damping.items(): ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) return ops - def get_inverse(self, damping): + def get_damped_inverse(self, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._inverses_by_damping[damping] def get_matpower(self, exp, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._matpower_by_exp_and_damping[(exp, damping)] def get_eigendecomp(self): + # Unlike get_inverse and get_matpower this doesn't retrieve a stored + # variable, but instead always computes a fresh version from the current + # value of get_cov(). return self._eigendecomp @@ -356,12 +424,21 @@ class FullFactor(InverseProvidingFactor): to any type of parameter in principle, but has very high variance. """ - def __init__(self, params_grads, batch_size): + def __init__(self, + params_grads, + batch_size, + colocate_cov_ops_with_inputs=False): self._batch_size = batch_size + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs self._orig_params_grads_name = scope_string_from_params( [params_grads, self._batch_size]) - self._params_grads_flat = tuple( - utils.tensors_to_column(params_grad) for params_grad in params_grads) + params_grads_flat = [] + for params_grad in params_grads: + with _maybe_colocate_with(params_grad, + self._colocate_cov_ops_with_inputs): + col = utils.tensors_to_column(params_grad) + params_grads_flat.append(col) + self._params_grads_flat = tuple(params_grads_flat) super(FullFactor, self).__init__() @property @@ -377,11 +454,17 @@ class FullFactor(InverseProvidingFactor): def _num_sources(self): return len(self._params_grads_flat) + @property + def _dtype(self): + return self._params_grads_flat[0].dtype + def _compute_new_cov(self, idx=0): # This will be a very basic rank 1 estimate - return ((self._params_grads_flat[idx] * array_ops.transpose( - self._params_grads_flat[idx])) / math_ops.cast( - self._batch_size, self._params_grads_flat[idx].dtype)) + with _maybe_colocate_with(self._params_grads_flat[idx], + self._colocate_cov_ops_with_inputs): + return ((self._params_grads_flat[idx] * array_ops.transpose( + self._params_grads_flat[idx])) / math_ops.cast( + self._batch_size, self._params_grads_flat[idx].dtype)) class DiagonalFactor(FisherFactor): @@ -394,6 +477,9 @@ class DiagonalFactor(FisherFactor): def _cov_initializer(self): return diagonal_covariance_initializer + def make_inverse_update_ops(self): + return [] + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. @@ -402,10 +488,19 @@ class NaiveDiagonalFactor(DiagonalFactor): to any type of parameter in principle, but has very high variance. """ - def __init__(self, params_grads, batch_size): + def __init__(self, + params_grads, + batch_size, + colocate_cov_ops_with_inputs=False): self._batch_size = batch_size - self._params_grads = tuple( - utils.tensors_to_column(params_grad) for params_grad in params_grads) + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs + params_grads_flat = [] + for params_grad in params_grads: + with _maybe_colocate_with(params_grad, + self._colocate_cov_ops_with_inputs): + col = utils.tensors_to_column(params_grad) + params_grads_flat.append(col) + self._params_grads = tuple(params_grads_flat) self._orig_params_grads_name = scope_string_from_params( [self._params_grads, self._batch_size]) super(NaiveDiagonalFactor, self).__init__() @@ -422,9 +517,15 @@ class NaiveDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._params_grads) + @property + def _dtype(self): + return self._params_grads[0].dtype + def _compute_new_cov(self, idx=0): - return (math_ops.square(self._params_grads[idx]) / math_ops.cast( - self._batch_size, self._params_grads[idx].dtype)) + with _maybe_colocate_with(self._params_grads[idx], + self._colocate_cov_ops_with_inputs): + return (math_ops.square(self._params_grads[idx]) / math_ops.cast( + self._batch_size, self._params_grads[idx].dtype)) class FullyConnectedDiagonalFactor(DiagonalFactor): @@ -440,7 +541,11 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, has_bias=False): + def __init__(self, + inputs, + outputs_grads, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Instantiate FullyConnectedDiagonalFactor. Args: @@ -449,18 +554,22 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): outputs_grads: List of Tensors of shape [batch_size, output_size]. Gradient of loss with respect to layer's preactivations. has_bias: bool. If True, append '1' to each input. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs self._batch_size = array_ops.shape(inputs)[0] - self._orig_tensors_name = scope_string_from_params((inputs,) + - tuple(outputs_grads)) + self._orig_tensors_name = scope_string_from_params( + (inputs,) + tuple(outputs_grads)) # Note that we precompute the required operations on the inputs since the # inputs don't change with the 'idx' argument to _compute_new_cov. (Only # the target entry of _outputs_grads changes with idx.) - if has_bias: - inputs = _append_homog(inputs) - self._squared_inputs = math_ops.square(inputs) + with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): + if has_bias: + inputs = _append_homog(inputs) + self._squared_inputs = math_ops.square(inputs) super(FullyConnectedDiagonalFactor, self).__init__() @@ -476,17 +585,23 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. - new_cov = math_ops.matmul( - self._squared_inputs, - math_ops.square(self._outputs_grads[idx]), - transpose_a=True) - new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) - return new_cov + with _maybe_colocate_with(self._squared_inputs, + self._colocate_cov_ops_with_inputs): + new_cov = math_ops.matmul( + self._squared_inputs, + math_ops.square(self._outputs_grads[idx]), + transpose_a=True) + new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) + return new_cov class ConvDiagonalFactor(DiagonalFactor): @@ -494,8 +609,14 @@ class ConvDiagonalFactor(DiagonalFactor): # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, filter_shape, strides, padding, - has_bias=False): + def __init__(self, + inputs, + outputs_grads, + filter_shape, + strides, + padding, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Creates a ConvDiagonalFactor object. Args: @@ -510,29 +631,36 @@ class ConvDiagonalFactor(DiagonalFactor): padding: The padding in this layer (1-D of Tensor length 4). has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._filter_shape = filter_shape self._has_bias = has_bias self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - self._orig_tensors_name = scope_string_from_name((inputs,) - + tuple(outputs_grads)) + self._orig_tensors_name = scope_string_from_name( + (inputs,) + tuple(outputs_grads)) # Note that we precompute the required operations on the inputs since the # inputs don't change with the 'idx' argument to _compute_new_cov. (Only # the target entry of _outputs_grads changes with idx.) - filter_height, filter_width, _, _ = self._filter_shape - patches = array_ops.extract_image_patches( - inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=strides, - rates=[1, 1, 1, 1], - padding=padding) + with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): + filter_height, filter_width, _, _ = self._filter_shape - if has_bias: - patches = _append_homog(patches) + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + patches = array_ops.extract_image_patches( + inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=strides, + rates=[1, 1, 1, 1], + padding=padding) + + if has_bias: + patches = _append_homog(patches) - self._patches = patches + self._patches = patches super(ConvDiagonalFactor, self).__init__() @@ -543,21 +671,29 @@ class ConvDiagonalFactor(DiagonalFactor): @property def _cov_shape(self): filter_height, filter_width, in_channels, out_channels = self._filter_shape - return [filter_height * filter_width * in_channels + self._has_bias, - out_channels] + return [ + filter_height * filter_width * in_channels + self._has_bias, + out_channels + ] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): - outputs_grad = self._outputs_grads[idx] - batch_size = array_ops.shape(self._patches)[0] + with _maybe_colocate_with(self._outputs_grads[idx], + self._colocate_cov_ops_with_inputs): + outputs_grad = self._outputs_grads[idx] + batch_size = array_ops.shape(self._patches)[0] - new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) - new_cov /= math_ops.cast(batch_size, new_cov.dtype) + new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) - return new_cov + return new_cov def _convdiag_sum_of_squares(self, patches, outputs_grad): # This computes the sum of the squares of the per-training-case "gradients". @@ -572,19 +708,24 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): """Kronecker factor for the input or output side of a fully-connected layer. """ - def __init__(self, tensors, has_bias=False): + def __init__(self, + tensors, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Instantiate FullyConnectedKroneckerFactor. Args: tensors: List of Tensors of shape [batch_size, n]. Represents either a layer's inputs or its output's gradients. - has_bias: bool. If True, assume this factor is for the layer's inputs and - append '1' to each row. + has_bias: bool. If True, append '1' to each row. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ # The tensor argument is either a tensor of input activations or a tensor of # output pre-activation gradients. self._has_bias = has_bias self._tensors = tensors + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(FullyConnectedKroneckerFactor, self).__init__() @property @@ -601,11 +742,17 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._tensors) + @property + def _dtype(self): + return self._tensors[0].dtype + def _compute_new_cov(self, idx=0): - tensor = self._tensors[idx] - if self._has_bias: - tensor = _append_homog(tensor) - return _compute_cov(tensor) + with _maybe_colocate_with(self._tensors[idx], + self._colocate_cov_ops_with_inputs): + tensor = self._tensors[idx] + if self._has_bias: + tensor = _append_homog(tensor) + return _compute_cov(tensor) class ConvInputKroneckerFactor(InverseProvidingFactor): @@ -618,7 +765,13 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, inputs, filter_shape, strides, padding, has_bias=False): + def __init__(self, + inputs, + filter_shape, + strides, + padding, + has_bias=False, + colocate_cov_ops_with_inputs=False): """Initializes ConvInputKroneckerFactor. Args: @@ -630,12 +783,15 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): width_stride, in_channel_stride]. padding: str. Padding method for layer. "SAME" or "VALID". has_bias: bool. If True, append 1 to in_channel. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._filter_shape = filter_shape self._strides = strides self._padding = padding self._has_bias = has_bias self._inputs = inputs + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvInputKroneckerFactor, self).__init__() @property @@ -655,26 +811,34 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return self._inputs.dtype + def _compute_new_cov(self, idx=0): if idx != 0: raise ValueError("ConvInputKroneckerFactor only supports idx = 0") # TODO(jamesmartens): factor this patches stuff out into a utility function - filter_height, filter_width, in_channels, _ = self._filter_shape - patches = array_ops.extract_image_patches( - self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], - padding=self._padding) + with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs): + filter_height, filter_width, in_channels, _ = self._filter_shape - flatten_size = (filter_height * filter_width * in_channels) - patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + patches = array_ops.extract_image_patches( + self._inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], + padding=self._padding) - if self._has_bias: - patches_flat = _append_homog(patches_flat) + flatten_size = (filter_height * filter_width * in_channels) + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - return _compute_cov(patches_flat) + if self._has_bias: + patches_flat = _append_homog(patches_flat) + + return _compute_cov(patches_flat) class ConvOutputKroneckerFactor(InverseProvidingFactor): @@ -688,15 +852,18 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads): + def __init__(self, outputs_grads, colocate_cov_ops_with_inputs=False): """Initializes ConvOutputKroneckerFactor. Args: outputs_grads: list of Tensors. Each Tensor is of shape - [batch_size, height, width, out_channels]. + [batch_size, height, width, out_channels]. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. """ self._out_channels = outputs_grads[0].shape.as_list()[3] self._outputs_grads = outputs_grads + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvOutputKroneckerFactor, self).__init__() @property @@ -712,7 +879,286 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): - reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], - [-1, self._out_channels]) - return _compute_cov(reshaped_tensor) + with _maybe_colocate_with(self._outputs_grads[idx], + self._colocate_cov_ops_with_inputs): + reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], + [-1, self._out_channels]) + return _compute_cov(reshaped_tensor) + + +class FullyConnectedMultiKF(InverseProvidingFactor): + """Kronecker factor for a fully connected recurrent layer.""" + + def __init__(self, + tensor_lists, + has_bias=False, + colocate_cov_ops_with_inputs=False): + """Constructs a new `FullyConnectedMultiKF`. + + Args: + tensor_lists: List of lists of Tensors of shape [batch_size, n]. + has_bias: bool. If True, '1' is appended to each row. + colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with + their inputs. + """ + + self._orig_tensors_name = scope_string_from_params(tensor_lists) + self._batch_size = array_ops.shape(tensor_lists[0][0])[0] + self._num_timesteps = len(tensor_lists[0]) + + tensors = tuple( + array_ops.concat(tensor_list, 0) for tensor_list in tensor_lists) + if has_bias: + tensors = tuple(_append_homog(tensor) for tensor in tensors) + self._tensors = tensors + + self._cov_dt1 = None + self._option1quants_by_damping = {} + self._option2quants_by_damping = {} + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs + + super(FullyConnectedMultiKF, self).__init__() + + @property + def _var_scope(self): + return "ff_fc_multi/" + self._orig_tensors_name + + @property + def _num_sources(self): + return len(self._tensors) + + @property + def _dtype(self): + return self._tensors[0].dtype + + def make_covariance_update_op(self, ema_decay): + with _maybe_colocate_with(self._tensors, + self._colocate_cov_ops_with_inputs): + op = super(FullyConnectedMultiKF, + self).make_covariance_update_op(ema_decay) + + if self._cov_dt1 is not None: + new_cov_dt1 = math_ops.add_n( + tuple( + self._compute_new_cov_dt1(idx) + for idx in range(self._num_sources))) + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) + + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) + + return op + + def _compute_new_cov(self, idx=0): + tensor = self._tensors[idx] + normalizer = self._num_timesteps * self._batch_size + return _compute_cov(tensor, normalizer=normalizer) + + def _compute_new_cov_dt1(self, idx=0): + tensor = self._tensors[idx] + normalizer = self._num_timesteps * self._batch_size + tensor_present = tensor[:-self._batch_size, :] + tensor_future = tensor[self._batch_size:, :] + return _compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=normalizer) + + @property + def _cov_shape(self): + size = self._tensors[0].shape[1] + return [size, size] + + @property + def _vec_shape(self): + size = self._tensors[0].shape[1] + return [size] + + def get_option1quants(self, damping): + return self._option1quants_by_damping[damping] + + def get_option2quants(self, damping): + return self._option2quants_by_damping[damping] + + def get_cov_dt1(self): + assert self._cov_dt1 is not None + return self._cov_dt1 + + def register_cov_dt1(self): + """Create a variable representing temporal cross-covariance. + + (This is technically the second moment, not covariance, since it's + not mean subtracted.) + """ + if self._cov_dt1 is None: + with variable_scope.variable_scope(self._var_scope): + self._cov_dt1 = variable_scope.get_variable( + "cov_dt1", + initializer=init_ops.zeros_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + def register_option1quants(self, damping): + + self.register_eigendecomp() + self.register_cov_dt1() + + if damping not in self._option1quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Lmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + psi = variable_scope.get_variable( + "psi_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option1quants_by_damping[damping] = (Lmat, psi) + + def register_option2quants(self, damping): + + self.register_eigendecomp() + self.register_cov_dt1() + + if damping not in self._option2quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Pmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + Kmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Kmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + mu = variable_scope.get_variable( + "mu_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option2quants_by_damping[damping] = (Pmat, Kmat, mu) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + # TODO(b/69918258): Add correctness tests for this method. + # pylint: disable=invalid-name + + ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops() + + if (len(self._option1quants_by_damping) + + len(self._option2quants_by_damping)): + + # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from + # the pseudo-code in the original paper. Because the computations for + # the A and G case are essentially the same they can both be performed by + # the same class (this one). + + C1 = self.get_cov_dt1() + + # Get the eigendecomposition of C0 (= self.get_cov()) + eigen_e, eigen_V = self.get_eigendecomp() + + # TODO(b/69678661): Note, there is an implicit assumption here that C1 + # and C0 (as represented here by its eigen-decomp) are consistent. This + # could fail to be the case if self._cov and self._cov_dt1 are not updated + # consistently, or are somehow read between or during the cov updates. + # Can this possibly happen? Is there a way to prevent it? + + for damping, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # The following line imposses the symmetry assumed by "Option 1" on C1. + # Stangely the code can work okay with this line commented out, + # depending on how psd_eig is defined. I'm not sure why. + C1 = (C1 + array_ops.transpose(C1)) / 2.0 + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means \hat{Psi}) + hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0) + + # Compute the decomposition U*diag(psi)*U^T = hPsi + psi, U = utils.posdef_eig(hPsi) + + # L = C0^(-1/2) * U + Lmat = math_ops.matmul(invsqrtC0, U) + + ops.append(Lmat_var.assign(Lmat)) + ops.append(psi_var.assign(psi)) + + for damping, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + # compute C0^(-1/2) + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # Compute the product C0^(-1/2) * C1 + invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1) + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means \hat{Psi}) + hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0) + + # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi + # Note that we using the notation mu instead of "m" for the eigenvalues. + # Instead of computing the product hPsi^T * hPsi and then doing an + # eigen-decomposition of this we just compute the SVD of hPsi and then + # square the singular values to get the eigenvalues. For a justification + # of this approach, see: + # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition + sqrtmu, _, E = linalg_ops.svd(hPsi) + mu = math_ops.square(sqrtmu) + + # Mathematically, the eigenvalues should not should not exceed 1.0, but + # due to numerical issues, or possible issues with inconsistent + # values of C1 and (the eigen-decomposition of) C0 they might. So + # we enforce this condition. + mu = math_ops.minimum(mu, 1.0) + + # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) + Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) + + # K = C_0^(-1/2) * E + Kmat = math_ops.matmul(invsqrtC0, E) + + ops.append(Pmat_var.assign(Pmat)) + ops.append(Kmat_var.assign(Kmat)) + ops.append(mu_var.assign(mu)) + + return [control_flow_ops.group(*ops)] + + # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 1806f5d8651e0b922fc30aed58d19de7faa5b265..ca42afe6fb2f5c7d7de8b5b087dc11be30a75d5e 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,7 +26,9 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from functools import partial +import math import six from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb @@ -35,20 +37,51 @@ from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest - # Names for various approximations that can be requested for Fisher blocks. APPROX_KRONECKER_NAME = "kron" APPROX_DIAGONAL_NAME = "diagonal" APPROX_FULL_NAME = "full" +_GENERIC_APPROX_TO_BLOCK_TYPES = { + APPROX_FULL_NAME: fb.FullFB, + APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, +} + +_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, +} + +_CONV2D_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, + APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, +} + +APPROX_KRONECKER_INDEP_NAME = "kron_indep" +APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" +APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" + +_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, + APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, + option=1), + APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, + option=2) +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" -# TODO(jamesmartens): need to add find_canonical_output back into this somewhere + +def ensure_sequence(obj): + """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" + if isinstance(obj, (tuple, list)): + return obj + else: + return (obj,) class LayerParametersDict(OrderedDict): @@ -103,21 +136,27 @@ class LayerCollection(object): fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer parameters (Tensors or tuples of Tensors) to FisherBlock instances. fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. - generic_registrations: a list of variables registered via a generic layer - registration. Generic registrations handle any and all of the ways a - variable is used in the graph, which means we don't need to check - their registration when verifying the correctness of the graph. losses: a list of LossFunction objects. The loss to be optimized is their sum. """ - def __init__(self, graph=None, name="LayerCollection"): + def __init__(self, + graph=None, + colocate_cov_ops_with_inputs=False, + name="LayerCollection"): self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() - self._generic_registrations = set() + self._linked_parameters = dict( + ) # dict mapping sets of variables to optionally specified approximations. self._graph = graph or ops.get_default_graph() self._loss_dict = {} # {str: LossFunction} self._subgraph = None + self._default_generic_approximation = APPROX_FULL_NAME + self._default_fully_connected_approximation = APPROX_KRONECKER_NAME + self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME + self._default_fully_connected_multi_approximation = ( + APPROX_KRONECKER_SERIES_2_NAME) + self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @@ -127,113 +166,195 @@ class LayerCollection(object): """LossFunctions registered with this LayerCollection.""" return list(self._loss_dict.values()) - def register_block(self, layer_key, fisher_block): - """Validates and registers the layer_key associated with the fisher_block. + @property + def registered_variables(self): + """A tuple of all of the variables currently registered.""" + tuple_of_tuples = (ensure_sequence(key) for key, block + in six.iteritems(self.fisher_blocks)) + flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) + return flat_tuple + + @property + def linked_parameters(self): + """Groups of parameters with an optionally specified approximation. - Validation consists of checking whether the key was already registered or - if any of the elements of layer_key (if it's a tuple) were already - registered as part of another tuple (throws an error if so). If any of the - elements were registered by themselves, or as part of tuples that are - subsets of this layer_key, those registrations are first removed. - - If the layer_key is a subset of an existing registration, registration of - the new, smaller layer_key is skipped. - - e.g. If registrations include {'a': foo, ('b', 'c'): bar}, then - - register_layer('a', baz) -> ValueError - - register_layer(('b', 'c', 'd'), baz) -> - {'a': foo, ('b', 'c', 'd'): baz} - - register_layer('b', baz) -> - {'a': foo, ('b', 'c'): bar} (No change) - - register_layer(('a', 'd'), baz) -> - {('a', 'd'): baz, ('b', 'c'): bar} - - register_layer(('b', 'd'), baz) -> ValueError + Linked parameters can be added using `define_linked_parameters`. + If an approximation is specified, then this approximation will be used + when registering a layer with exactly these parameters, unless an + approximation is specified when calling the registration function. + + Returns: + A `dict` mapping tuples of parameters to an optional string. + """ + return self._linked_parameters + + @property + def default_generic_approximation(self): + return self._default_generic_approximation + + def set_default_generic_approximation(self, value): + if value not in _GENERIC_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for generic variables.".format( + value)) + self._default_generic_approximation = value + + @property + def default_fully_connected_approximation(self): + return self._default_fully_connected_approximation + + def set_default_fully_connected_approximation(self, value): + if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for fully connected layers.".format( + value)) + self._default_fully_connected_approximation = value + + @property + def default_conv2d_approximation(self): + return self._default_convolution_2d_approximation + + def set_default_conv2d_approximation(self, value): + if value not in _CONV2D_APPROX_TO_BLOCK_TYPES: + raise ValueError( + "{} is not a valid approximation for 2d convolutional layers.".format( + value)) + self._default_convolution_2d_approximation = value + + @property + def default_fully_connected_multi_approximation(self): + return self._default_fully_connected_multi_approximation + + def set_default_fully_connected_multi_approximation(self, value): + if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("{} is not a valid approximation for a fully-connected " + "multi layer.".format(value)) + self._default_fully_connected_multi_approximation = value + + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): + """Validates and registers the layer_key associated with the fisher_block. Args: - layer_key: The key to check for in existing registrations and to register - if valid. - fisher_block: The associated fisher block. + layer_key: A variable or tuple of variables. The key to check for in + existing registrations and to register if valid. + fisher_block: The associated `FisherBlock`. + reuse: Method to use for inserting new `FisherBlock`s. One of True, False, + or 'VARIABLE_SCOPE'. Raises: - ValueError: If the layer_key was already registered, or if a subset of the - layer_key has already been registered as part of a different tuple. + ValueError: If `layer_key` was already registered and reuse is `False`, + if `layer_key` was registered with a different block type, or if + `layer_key` shares any variables with but is not equal to a previously + registered key. + KeyError: If `reuse` is `True` but `layer_key` was not previously + registered. + + Returns: + The `FisherBlock` registered under `layer_key`. If `layer_key` was already + registered, this will be the previously registered `FisherBlock`. """ + if reuse is VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse is True or (reuse is variable_scope.AUTO_REUSE and + layer_key in self.fisher_blocks): + result = self.fisher_blocks[layer_key] + if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck + raise ValueError( + "Attempted to register FisherBlock of type %s when existing " + "FisherBlock has type %s." % (type(fisher_block), type(result))) + return result + if reuse is False and layer_key in self.fisher_blocks: + raise ValueError("FisherBlock for %s is already in LayerCollection." % + (layer_key,)) + + # Insert fisher_block into self.fisher_blocks. if layer_key in self.fisher_blocks: raise ValueError("Duplicate registration: {}".format(layer_key)) - if isinstance(layer_key, (tuple, list)): - self._register_block_with_sequence_key(layer_key, fisher_block) - else: - self._register_block_with_nonsequence_key(layer_key, fisher_block) - - def _register_block_with_sequence_key(self, layer_key, fisher_block): - """Validates and registers the layer_key if it's a sequence.""" - inclusions = { - fisher_elt - for layer_elt in layer_key for fisher_elt in self.fisher_blocks - if self._equal_or_subset(layer_elt, fisher_elt) + # Raise an error if any variable in layer_key has been registered in any + # other blocks. + variable_to_block = { + var: (params, block) + for (params, block) in self.fisher_blocks.items() + for var in ensure_sequence(params) } - - if not inclusions: - self.fisher_blocks[layer_key] = fisher_block - return - - for key in inclusions: - fisher_block_key = key if isinstance(key, (tuple, list)) else (key,) - if set(layer_key).issubset(fisher_block_key): - logging.warning("Graph Registration Warning: tried to register " - "a subset ({}) of an already registered tuple " - "({}), skipping".format(layer_key, fisher_block_key)) - return - if not set(fisher_block_key).issubset(layer_key): + for variable in ensure_sequence(layer_key): + if variable in variable_to_block: + prev_key, prev_block = variable_to_block[variable] raise ValueError( - "Inconsistent registration, expected new key to be a subset or " - "superset of the existing key: existing is {}, new is {}".format( - key, layer_key)) - else: - self.fisher_blocks.pop(key) - + "Attempted to register layer_key {} with block {}, but variable {}" + " was already registered in key {} with block {}.".format( + layer_key, fisher_block, variable, prev_key, prev_block)) self.fisher_blocks[layer_key] = fisher_block - - def _register_block_with_nonsequence_key(self, layer_key, fisher_block): - """Validates and registers the layer_key if it's not a sequence.""" - inclusions = { - fisher_elt - for fisher_elt in self.fisher_blocks - if self._equal_or_subset(layer_key, fisher_elt) - } - - if not inclusions: - self.fisher_blocks[layer_key] = fisher_block - else: - logging.warning("Graph Registration Warning: tried to register " - "variable ({}) but a containing tuple was already " - "registered ({}), skipping".format(layer_key, inclusions)) - - def _equal_or_subset(self, elt1, elt2): - """Checks if the elements are equal or one is contained in the other.""" - return (elt1 == elt2 or (isinstance(elt1, - (tuple, list)) and elt2 in elt1) or - (isinstance(elt2, (tuple, list)) and elt1 in elt2)) + return fisher_block def get_use_count_map(self): """Returns a dict of variables to their number of registrations.""" + # TODO(b/70283403): Reimplement this in the old way, where each + # registration function would be responsible for incrementing the count. + # Also, this version has a bug: it won't do the right thing for generic + # registration for parameters that are shared. i.e. it won't set the use + # count to infinity. vars_to_uses = defaultdict(int) for key, block in six.iteritems(self.fisher_blocks): - key = key if isinstance(key, (tuple, list)) else (key,) + n = ( + block.num_inputs()*block.num_registered_minibatches if isinstance( + block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB)) + else block.num_registered_minibatches) + key = ensure_sequence(key) for k in key: - vars_to_uses[k] += block.num_registered_minibatches + vars_to_uses[k] += n return vars_to_uses + def check_registration(self, variables): + """Checks that all variable uses have been registered properly. + + Args: + variables: List of variables. + + Raises: + ValueError: If any registered variables are not included in the list. + ValueError: If any variable in the list is not registered. + ValueError: If any variable in the list is registered with the wrong + number of "uses" in the subgraph recorded (vs the number of times that + variable is actually used in the subgraph). + """ + # Note that overlapping parameters (i.e. those that share variables) will + # be caught by layer_collection.LayerParametersDict during registration. + + reg_use_map = self.get_use_count_map() + + error_messages = [] + + for var in variables: + total_uses = self.subgraph.variable_uses(var) + reg_uses = reg_use_map[var] + + if reg_uses == 0: + error_messages.append("Variable {} not registered.".format(var)) + elif (not math.isinf(reg_uses)) and reg_uses != total_uses: + error_messages.append( + "Variable {} registered with wrong number of uses ({} " + "registrations vs {} uses).".format(var, reg_uses, total_uses)) + + num_get_vars = len(reg_use_map) + + if num_get_vars > len(variables): + error_messages.append("{} registered variables were not included in list." + .format(num_get_vars - len(variables))) + + if error_messages: + error_messages = [ + "Found the following errors with variable registration:" + ] + error_messages + raise ValueError("\n\t".join(error_messages)) + def get_blocks(self): return self.fisher_blocks.values() def get_factors(self): return self.fisher_factors.values() - @property - def generic_registrations(self): - return self._generic_registrations - @property def graph(self): return self._graph @@ -242,6 +363,49 @@ class LayerCollection(object): def subgraph(self): return self._subgraph + def define_linked_parameters(self, params, approximation=None): + """Identify a set of parameters that should be grouped together. + + During automatic graph scanning, any matches containing variables that have + been identified as part of a linked group will be filtered out unless + the match parameters are exactly equal to the ones specified in the linked + group. + + Args: + params: A variable, or a tuple or list of variables. The variables + to be linked. + approximation: Optional string specifying the type of approximation to use + for these variables. If unspecified, this layer collection's default + approximation for the layer type will be used. + + Raises: + ValueError: If the parameters were already registered in a layer or + identified as part of an incompatible group. + """ + params = frozenset(ensure_sequence(params)) + + # Check if any of the variables in 'params' is already in + # 'self.fisher_blocks.keys()'. + for registered_params, fisher_block in self.fisher_blocks.items(): + registered_params_set = set(ensure_sequence(registered_params)) + for variable in params: + if (variable in registered_params_set and + params != registered_params_set): + raise ValueError( + "Can't link parameters {}, variable {} was already registered in " + "group {} with layer {}".format(params, variable, + registered_params, fisher_block)) + + # Check if any of the variables in 'params' is already in + # 'self.linked_parameters'. + for variable in params: + for other_linked_params in self.linked_parameters: + if variable in other_linked_params: + raise ValueError("Can't link parameters {}, variable {} was already " + "linked in group {}.".format(params, variable, + other_linked_params)) + self._linked_parameters[params] = approximation + def create_subgraph(self): if not self.losses: raise ValueError("Must have at least one registered loss.") @@ -255,11 +419,19 @@ class LayerCollection(object): return math_ops.add_n( tuple(loss.evaluate_on_sample() for loss in self.losses)) + def _get_linked_approx(self, params): + """If params were linked, return their specified approximation.""" + params_set = frozenset(ensure_sequence(params)) + if params_set in self.linked_parameters: + return self.linked_parameters[params_set] + else: + return None + def register_fully_connected(self, params, inputs, outputs, - approx=APPROX_KRONECKER_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a fully connnected layer. @@ -268,11 +440,11 @@ class LayerCollection(object): this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_size]. Preactivations + outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -280,35 +452,18 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - approx_to_block_types = { - APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, - APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, - } + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_approximation - if approx not in approx_to_block_types: + if approx not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES[approx] has_bias = isinstance(params, (tuple, list)) - if reuse == VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse - - if reuse: - block = self.fisher_blocks.get(params, None) - if block is None: - raise KeyError( - "Reuse requested but no FisherBlock found for params {}.".format( - params)) - if not isinstance(block, block_type): - raise ValueError( - "Requested block of type {} but block of type {} already exists " - "for params {}.".format(block_type, type(block), params)) - - else: - block = block_type(self, has_bias) - self.register_block(params, block) - + block = self.register_block(params, block_type(self, has_bias), reuse=reuse) block.register_additional_minibatch(inputs, outputs) def register_conv2d(self, @@ -317,7 +472,7 @@ class LayerCollection(object): padding, inputs, outputs, - approx=APPROX_KRONECKER_NAME, + approx=None, reuse=VARIABLE_SCOPE): """Registers a convolutional layer. @@ -331,10 +486,10 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. - Preactivations produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + Output produced by layer. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -342,50 +497,93 @@ class LayerCollection(object): KeyError: If reuse == True but no FisherBlock found for 'params'. ValueError: If reuse == True and FisherBlock found but of the wrong type. """ - approx_to_block_types = { - APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, - APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, - } - if approx not in approx_to_block_types: + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_conv2d_approximation + + if approx not in _CONV2D_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) - block_type = approx_to_block_types[approx] + block_type = _CONV2D_APPROX_TO_BLOCK_TYPES[approx] + block = self.register_block( + params, block_type(self, params, strides, padding), reuse=reuse) + block.register_additional_minibatch(inputs, outputs) - if reuse == VARIABLE_SCOPE: - reuse = variable_scope.get_variable_scope().reuse + def register_generic(self, + params, + batch_size, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a generic layer. - if reuse: - block = self.fisher_blocks.get(params, None) - if block is None: - raise KeyError( - "Reuse requested but no FisherBlock found for params {}.".format( - params)) - if not isinstance(block, block_type): - raise ValueError( - "Requested block of type {} but block of type {} already exists " - "for params {}.".format(block_type, type(block), params)) + Args: + params: Tensor or tuple of Tensors corresponding to the parameters. + batch_size: 0-D Tensor. Size of the minibatch. + approx: str. One of "full" or "diagonal". + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. - else: - block = block_type(self, params, strides, padding) - self.register_block(params, block) + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ - block.register_additional_minibatch(inputs, outputs) + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_generic_approximation - def register_generic(self, params, batch_size, approx=APPROX_DIAGONAL_NAME): - params = params if isinstance(params, (tuple, list)) else (params,) - self._generic_registrations |= set(params) - - # Generic registrations do not need special registration rules because we do - # not care about multiple generic registrations. Add them to the - # fisher_block dictionary manually rather than going through the logic in - # self.register_block. - if approx == APPROX_FULL_NAME: - self.fisher_blocks[params] = fb.FullFB(self, params, batch_size) - elif approx == APPROX_DIAGONAL_NAME: - self.fisher_blocks[params] = fb.NaiveDiagonalFB(self, params, batch_size) - else: + if approx not in _GENERIC_APPROX_TO_BLOCK_TYPES: + raise ValueError("Bad value {} for approx.".format(approx)) + + block_type = _GENERIC_APPROX_TO_BLOCK_TYPES[approx] + block = self.register_block(params, block_type(self, params), reuse=reuse) + block.register_additional_minibatch(batch_size) + + def register_fully_connected_multi(self, params, inputs, outputs, + approx=None): + """Register fully connected layers with shared parameters. + + This can handle general fully-connected layers with shared parameters, but + has specialized approximations to deal with the case where there is a + meaningful linear order to the share instances (such as in an RNN). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs + to layer. In the case of RNNs, one Tensor per time step. + outputs: A list of tensors, the same length as 'inputs', each of shape + [batch_size, output_size]. Outputs produced by layer. In the case of + RNNs, one Tensor per time step. + approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". + + Raises: + ValueError: For improper value to 'approx'. + """ + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_multi_approximation + has_bias = isinstance(params, (tuple, list)) + + # TODO(b/70283649): something along the lines of find_canonical_output + # should be added back in here (and for the other block types, arguably). + + if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: raise ValueError("Bad value {} for approx.".format(approx)) + block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] + + # For now we don't support multiple minibatches for this type of layer, so + # we set reuse=False + self.register_block(params, + block_type(self, inputs, outputs, has_bias=has_bias), + reuse=False) def register_categorical_predictive_distribution(self, logits, @@ -410,10 +608,10 @@ class LayerCollection(object): tf.get_variable_scope().reuse. Raises: - ValueError: If reuse=True and name != None. - ValueError: If reuse=True and seed != None. - KeyError: If reuse=True and no existing LossFunction with 'name' found. - KeyError: If reuse=False and existing LossFunction with 'name' found. + ValueError: If reuse == True and name == None. + ValueError: If reuse == True and seed != None. + KeyError: If reuse == True and no existing LossFunction with 'name' found. + KeyError: If reuse == False and existing LossFunction with 'name' found. """ name = name or self._graph.unique_name( "register_categorical_predictive_distribution") @@ -522,11 +720,14 @@ class LayerCollection(object): try: hash(args) except TypeError: - raise TypeError(( - "Unable to use (cls, args) = ({}, {}) as a key in " - "LayerCollection.fisher_factors. The pair cannot be hashed." - ).format(cls, args)) - - with variable_scope.variable_scope(self._var_scope): - return utils.setdefault(self.fisher_factors, (cls, args), - lambda: cls(*args)) + raise TypeError( + ("Unable to use (cls, args) = ({}, {}) as a key in " + "LayerCollection.fisher_factors. The pair cannot be hashed.").format( + cls, args)) + + key = cls, args + if key not in self.fisher_factors: + colo = self._colocate_cov_ops_with_inputs + with variable_scope.variable_scope(self._var_scope): + self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo) + return self.fisher_factors[key] diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index 3cfde7f9ababab73980e93ea1dd65be1b559712b..e2e5bc3ffea3e52087c24802948bc8260e3b199a 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -56,6 +56,30 @@ class LossFunction(object): """The inputs to the loss function (excluding the targets).""" pass + @property + def input_minibatches(self): + """A `list` of inputs to the loss function, separated by minibatch. + + Typically there will be one minibatch per tower in a multi-tower setup. + Returns a list consisting of `self.inputs` by default; `LossFunction`s + supporting registering multiple minibatches should override this method. + + Returns: + A `list` of `Tensor`s representing + """ + return [self.inputs] + + @property + def num_registered_minibatches(self): + """Number of minibatches registered for this LossFunction. + + Typically equal to the number of towers in a multi-tower setup. + + Returns: + An `int` representing the number of registered minibatches. + """ + return len(self.input_minibatches) + def evaluate(self): """Evaluate the loss function on the targets.""" if self.targets is not None: @@ -75,7 +99,6 @@ class LossFunction(object): Returns: log probability of each target, summed across all targets. """ - pass @abc.abstractmethod @@ -415,8 +438,8 @@ class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype), axis=-1) output_slice = self._var**-0.5 * ones_slice - return insert_slice_in_zeros(output_slice, 1, - int(self._mean.shape[1]), index[0]) + return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]), + index[0]) @property def fisher_factor_inner_shape(self): @@ -474,24 +497,23 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def _fisher_mean(self): - return 1./self._variance + return 1. / self._variance @property def _fisher_mean_factor(self): - return 1./self._scale + return 1. / self._scale @property def _fisher_var(self): - return 1./(2*math_ops.square(self._variance)) + return 1. / (2 * math_ops.square(self._variance)) @property def _fisher_var_factor(self): - return 1./(math_ops.sqrt(2.)*self._variance) + return 1. / (math_ops.sqrt(2.) * self._variance) def multiply_fisher(self, vecs): mean_vec, var_vec = vecs - return (self._fisher_mean * mean_vec, - self._fisher_var * var_vec) + return (self._fisher_mean * mean_vec, self._fisher_var * var_vec) def multiply_fisher_factor(self, vecs): mean_vec, var_vec = self._split(vecs) @@ -511,8 +533,8 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): # Index corresponds to mean parameter. mean_slice = self._fisher_mean_factor[:, index] mean_slice = array_ops.expand_dims(mean_slice, axis=-1) - mean_output = insert_slice_in_zeros(mean_slice, 1, - int(self._mean.shape[1]), index) + mean_output = insert_slice_in_zeros(mean_slice, 1, int( + self._mean.shape[1]), index) var_output = array_ops.zeros_like(mean_output) else: index -= int(self._mean.shape[-1]) @@ -527,13 +549,17 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): @property def fisher_factor_inner_shape(self): - return array_ops.concat([array_ops.shape(self._mean)[:-1], - 2*array_ops.shape(self._mean)[-1:]], axis=0) + return array_ops.concat( + [ + array_ops.shape(self._mean)[:-1], + 2 * array_ops.shape(self._mean)[-1:] + ], + axis=0) @property def fisher_factor_inner_static_shape(self): shape = self._mean.shape.as_list() - return tensor_shape.TensorShape(shape[-1:] + [2*shape[-1]]) + return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]]) def multiply_hessian(self, vector): raise NotImplementedError() @@ -605,6 +631,10 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, def _logits(self): return array_ops.concat(self._logits_components, axis=0) + @property + def input_minibatches(self): + return self._logits_components + @property def targets(self): if all(target is None for target in self._targets_components): @@ -710,8 +740,8 @@ class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, assert len(index) == 1, "Length of index was {}".format(len(index)) probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) - return insert_slice_in_zeros(output_slice, 1, - int(self._logits.shape[1]), index[0]) + return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]), + index[0]) @property def fisher_factor_inner_shape(self): diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index bfa15e0948c96477d9a79dece985bc4b6dafab6f..ecf7f3e4e5ab7d9c151f760fdab733bc3830e37b 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -35,16 +35,20 @@ from tensorflow.python.training import gradient_descent class KfacOptimizer(gradient_descent.GradientDescentOptimizer): """The KFAC Optimizer (https://arxiv.org/abs/1503.05671).""" - def __init__( - self, - learning_rate, - cov_ema_decay, - damping, - layer_collection, - momentum=0., - momentum_type="regular", - norm_constraint=None, - name="KFAC",): + def __init__(self, + learning_rate, + cov_ema_decay, + damping, + layer_collection, + var_list=None, + momentum=0., + momentum_type="regular", + norm_constraint=None, + name="KFAC", + estimation_mode="gradients", + colocate_gradients_with_ops=False, + cov_devices=None, + inv_devices=None): """Initializes the KFAC optimizer with the given settings. Args: @@ -63,6 +67,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): blocks, kronecker factors, and losses associated with the graph. The layer_collection cannot be modified after KfacOptimizer's initialization. + var_list: Optional list or tuple of variables to train. Defaults to the + list of variables collected in the graph under the key + `GraphKeys.TRAINABLE_VARIABLES`. momentum: The momentum value for this optimizer. Only applies when momentum_type is 'regular' or 'adam'. (Default: 0) momentum_type: The type of momentum to use in this optimizer, one of @@ -72,6 +79,18 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): specified value. May only be used with momentum type 'regular'. (Default: None) name: The name for this optimizer. (Default: 'KFAC') + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + (Default: 'gradients'). See the doc-string for FisherEstimator for + more a more detailed description of these options. + colocate_gradients_with_ops: Whether we should request gradients we + compute in the estimator be colocated with their respective ops. + cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. + inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion + computations will be placed on these devices in a round-robin fashion. + Can be None, which means that no devices are specified. Raises: ValueError: If the momentum type is unsupported. @@ -81,12 +100,19 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): or 'adam'. """ - # We may consider determining the set of variables some other way, but for - # now it's just all the trainable variables. - variables = tf_variables.trainable_variables() + variables = var_list + if variables is None: + variables = tf_variables.trainable_variables() - self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping, - layer_collection) + self._fisher_est = est.FisherEstimator( + variables, + cov_ema_decay, + damping, + layer_collection, + estimation_mode=estimation_mode, + colocate_gradients_with_ops=colocate_gradients_with_ops, + cov_devices=cov_devices, + inv_devices=inv_devices) momentum_type = momentum_type.lower() legal_momentum_types = ["regular", "adam", "qmodel"] @@ -101,7 +127,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): raise ValueError("Momentum must be unspecified if using a momentum_type " "other than 'regular' or 'adam'.") - self._momentum = ops.convert_to_tensor(momentum, name="momentum") + self._momentum = momentum self._momentum_type = momentum_type self._norm_constraint = norm_constraint @@ -125,16 +151,24 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): return self._fisher_est.damping def minimize(self, *args, **kwargs): - - if "var_list" not in kwargs: - kwargs["var_list"] = tf_variables.trainable_variables() - + kwargs["var_list"] = kwargs.get("var_list") or self.variables if set(kwargs["var_list"]) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") - return super(KfacOptimizer, self).minimize(*args, **kwargs) + def compute_gradients(self, *args, **kwargs): + # args[1] could be our var_list + if len(args) > 1: + var_list = args[1] + else: + kwargs["var_list"] = kwargs.get("var_list") or self.variables + var_list = kwargs["var_list"] + if set(var_list) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + return super(KfacOptimizer, self).compute_gradients(*args, **kwargs) + def apply_gradients(self, grads_and_vars, *args, **kwargs): """Applies gradients to variables. @@ -291,14 +325,17 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._batch_size, dtype=fft_precon_grads[0].dtype) # compute the entries of the 2x2 matrix - m_11 = (_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size - + self.damping * _inner_product_list(precon_grads, precon_grads)) + m_11 = ( + _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(precon_grads, precon_grads)) - m_21 = (_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size - + self.damping * _inner_product_list(prev_updates, precon_grads)) + m_21 = ( + _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(prev_updates, precon_grads)) - m_22 = (_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size - + self.damping * _inner_product_list(prev_updates, prev_updates)) + m_22 = ( + _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size + + self.damping * _inner_product_list(prev_updates, prev_updates)) def non_zero_prevupd_case(): r"""Computes optimal (alpha, mu) given non-zero previous update. @@ -384,8 +421,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): grads = list(grad for (grad, _) in grads_and_vars) variables = list(var for (_, var) in grads_and_vars) # previous updates are the negative velocities (up to scaling by LR) - prev_updates = list(-self._zeros_slot(var, "velocity", self._name) - for var in variables) + prev_updates = list( + -self._zeros_slot(var, "velocity", self._name) for var in variables) # Compute optimal velocity update parameters according to quadratic model alpha, mu, _ = self._compute_qmodel_hyperparams( diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index a7473481e44da0b09c047db9af29032918ea6cef..cec018e406bc51c07f5cafcc2c38efe7e9601618 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -28,9 +28,17 @@ from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops - # Method used for inverting matrices. POSDEF_INV_METHOD = "cholesky" +POSDEF_EIG_METHOD = "self_adjoint" + + +def set_global_constants(posdef_inv_method=None): + """Sets various global constants used by the classes in this module.""" + global POSDEF_INV_METHOD + + if posdef_inv_method is not None: + POSDEF_INV_METHOD = posdef_inv_method class SequenceDict(object): @@ -56,13 +64,6 @@ class SequenceDict(object): return list(self._dict.items()) -def setdefault(dct, key, thunk): - """Like dict.setdefault but delays evaluation of the value to be set.""" - if key not in dct: - dct[key] = thunk() - return dct[key] - - def tensors_to_column(tensors): """Converts a tensor or list of tensors to a column vector. @@ -161,33 +162,11 @@ def mat2d_to_layer_params(vector_template, mat2d): return array_ops.reshape(mat2d, vector_template.shape) -def compute_pi(left_factor, right_factor): - """Computes the scalar constant pi for Tikhonov regularization/damping. - - pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) - See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. - - Args: - left_factor: The left Kronecker factor Tensor. - right_factor: The right Kronecker factor Tensor. - - Returns: - The computed scalar constant pi for these Kronecker Factors (as a Tensor). - """ - # Instead of dividing by the dim of the norm, we multiply by the dim of the - # other norm. This works out the same in the ratio. - left_norm = math_ops.trace(left_factor) * right_factor.get_shape().as_list()[ - 0] - right_norm = math_ops.trace(right_factor) * left_factor.get_shape().as_list()[ - 0] - return math_ops.sqrt(left_norm / right_norm) - - def posdef_inv(tensor, damping): """Computes the inverse of tensor + damping * identity.""" identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) damping = math_ops.cast(damping, dtype=tensor.dtype) - return posdef_inv_funcs[POSDEF_INV_METHOD](tensor, identity, damping) + return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) def posdef_inv_matrix_inverse(tensor, identity, damping): @@ -201,9 +180,44 @@ def posdef_inv_cholesky(tensor, identity, damping): return linalg_ops.cholesky_solve(chol, identity) -posdef_inv_funcs = { +def posdef_inv_eig(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) with eigendecomposition.""" + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig( + tensor + damping * identity) + return math_ops.matmul( + eigenvectors / eigenvalues, eigenvectors, transpose_b=True) + + +posdef_inv_functions = { "matrix_inverse": posdef_inv_matrix_inverse, "cholesky": posdef_inv_cholesky, + "eig": posdef_inv_eig, +} + + +def posdef_eig(mat): + """Computes the eigendecomposition of a positive semidefinite matrix.""" + return posdef_eig_functions[POSDEF_EIG_METHOD](mat) + + +def posdef_eig_svd(mat): + """Computes the singular values and left singular vectors of a matrix.""" + evals, evecs, _ = linalg_ops.svd(mat) + + return evals, evecs + + +def posdef_eig_self_adjoint(mat): + """Computes eigendecomposition using self_adjoint_eig.""" + evals, evecs = linalg_ops.self_adjoint_eig(mat) + evals = math_ops.abs(evals) # Should be equivalent to svd approach. + + return evals, evecs + + +posdef_eig_functions = { + "self_adjoint": posdef_eig_self_adjoint, + "svd": posdef_eig_svd, } @@ -260,8 +274,8 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): # generated by the first gradients_impl.gradients call. us = [array_ops.zeros_like(y) + float("nan") for y in ys] - dydxs = gradients_impl.gradients(ys, xs, grad_ys=us, - stop_gradients=stop_gradients) + dydxs = gradients_impl.gradients( + ys, xs, grad_ys=us, stop_gradients=stop_gradients) # Deal with strange types that gradients_impl.gradients returns but can't # deal with. @@ -277,3 +291,6 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) return dysdx + +# TODO(b/69623235): Add a function for finding tensors that share gradients +# to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index ddbb4485ce6967082f1844c6d798c078f1cc303b..8903c90fbce6a890aa419d89b3b79d75f69509fc 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -25,13 +25,11 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ "SequenceDict", - "setdefault", "tensors_to_column", "column_to_tensors", "kronecker_product", "layer_params_to_mat2d", "mat2d_to_layer_params", - "compute_pi", "posdef_inv", "posdef_inv_matrix_inverse", "posdef_inv_cholesky", diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index 2f1f283811b6cb9e8bfb52ab2052afac1de700cb..852d06e1e3cc8f8deecd15b7436cd4e4a393ad66 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -61,6 +61,7 @@ tf_custom_op_py_library( "python/layers/normalization.py", "python/layers/optimizers.py", "python/layers/regularizers.py", + "python/layers/rev_block_lib.py", "python/layers/summaries.py", "python/layers/target_column.py", "python/layers/utils.py", @@ -376,6 +377,20 @@ py_test( ], ) +py_test( + name = "rev_block_lib_test", + size = "small", + srcs = ["python/layers/rev_block_lib_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:init_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index d309ba958ded86afdc1e4bba2ff471a5181cda4e..6c624929f20503054e0258aad8a843f4a201be64 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -42,6 +42,9 @@ See the @{$python/contrib.layers} guide. @@relu @@relu6 @@repeat +@@recompute_grad +@@RevBlock +@@rev_block @@safe_embedding_lookup_sparse @@scale_gradient @@separable_conv2d diff --git a/tensorflow/contrib/layers/python/layers/__init__.py b/tensorflow/contrib/layers/python/layers/__init__.py index 03337f9a5d11784316124442125bb498c4ce9603..f1ae2de68be33880a6fc09957f4d857973902b26 100644 --- a/tensorflow/contrib/layers/python/layers/__init__.py +++ b/tensorflow/contrib/layers/python/layers/__init__.py @@ -28,6 +28,7 @@ from tensorflow.contrib.layers.python.layers.layers import * from tensorflow.contrib.layers.python.layers.normalization import * from tensorflow.contrib.layers.python.layers.optimizers import * from tensorflow.contrib.layers.python.layers.regularizers import * +from tensorflow.contrib.layers.python.layers.rev_block_lib import * from tensorflow.contrib.layers.python.layers.summaries import * from tensorflow.contrib.layers.python.layers.target_column import * from tensorflow.contrib.layers.python.ops.bucketization_op import * diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 226d933d85d91600e36ffb84212703e10455bfbb..092d418c3f232b364e2c6b4d25a4162626ba17f0 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -521,7 +521,7 @@ def sparse_column_with_integerized_feature(column_name, Args: column_name: A string defining sparse column name. - bucket_size: An int that is > 1. The number of buckets. It should be bigger + bucket_size: An int that is >= 1. The number of buckets. It should be bigger than maximum feature. In other words features in this column should be an int64 in range [0, bucket_size) combiner: A string specifying how to reduce if the sparse column is @@ -539,7 +539,7 @@ def sparse_column_with_integerized_feature(column_name, An integerized _SparseColumn definition. Raises: - ValueError: bucket_size is not greater than 1. + ValueError: bucket_size is less than 1. ValueError: dtype is not integer. """ return _SparseColumnIntegerized( diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index fa0047f05d893f6543ddb1680824a32469e13293..78affea44cbfb92523063968dbc1be98841854db 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -97,10 +97,13 @@ def _input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank, - default_name): + default_name, + cols_to_outs=None): """Implementation of `input_from(_sequence)_feature_columns`.""" columns_to_tensors = columns_to_tensors.copy() check_feature_columns(feature_columns) + if cols_to_outs is not None and not isinstance(cols_to_outs, dict): + raise ValueError('cols_to_outs must be a dict unless None') with variable_scope.variable_scope(scope, default_name=default_name, values=columns_to_tensors.values()): @@ -144,6 +147,8 @@ def _input_from_feature_columns(columns_to_tensors, except ValueError as e: raise ValueError('Error creating input layer for column: {}.\n' '{}, {}'.format(column.name, e, ee)) + if cols_to_outs is not None: + cols_to_outs[column] = output_tensors[-1] return array_ops.concat(output_tensors, output_rank - 1) @@ -151,7 +156,8 @@ def input_from_feature_columns(columns_to_tensors, feature_columns, weight_collections=None, trainable=True, - scope=None): + scope=None, + cols_to_outs=None): """A tf.contrib.layers style input layer builder based on FeatureColumns. Generally a single example in training data is described with feature columns. @@ -196,6 +202,8 @@ def input_from_feature_columns(columns_to_tensors, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for variable_scope. + cols_to_outs: Optional dict from feature column to output tensor, + which is concatenated into the returned tensor. Returns: A Tensor which can be consumed by hidden layers in the neural network. @@ -209,7 +217,8 @@ def input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank=2, - default_name='input_from_feature_columns') + default_name='input_from_feature_columns', + cols_to_outs=cols_to_outs) @experimental diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index fbfa0e32de55edab3c90189ddfe05ab826ac9167..e6bbd86ab722c4e853a59f816bed8a8ac1fe9ede 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -607,6 +607,31 @@ class CreateInputLayersForDNNsTest(test.TestCase): # Verify cross compatibility: Core builder output should equal to contrib. self.assertAllEqual(output.eval().shape, output_core.eval().shape) + def testAllDNNColumnsWithColumnwiseOutputs(self): + sparse_column = feature_column.sparse_column_with_keys( + "ids", ["a", "b", "c", "unseen"]) + real_valued_column = feature_column.real_valued_column("income", 2) + one_hot_column = feature_column.one_hot_column(sparse_column) + embedding_column = feature_column.embedding_column(sparse_column, 10) + features = { + "ids": + sparse_tensor.SparseTensor( + values=["c", "b", "a"], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 1]), + "income": + constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]), + } + columns = [one_hot_column, embedding_column, real_valued_column] + cols_to_outs = {} + feature_column_ops.input_from_feature_columns( + features, columns, cols_to_outs=cols_to_outs) + with self.test_session(): + variables_lib.global_variables_initializer().run() + lookup_ops.tables_initializer().run() + for column in columns: + self.assertTrue(column in cols_to_outs) + def testRealValuedColumn(self): real_valued = feature_column.real_valued_column("price") features = {"price": constant_op.constant([[20.], [110], [-3]])} diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index b12a882d9ae88f7cf4f920cfa5872e5de1c67290..51610f21b24f1d40f26630cc1e69ca723d130639 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -79,7 +79,8 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, ``` * To get [Delving Deep into Rectifiers]( - http://arxiv.org/pdf/1502.01852v1.pdf), use (Default):
+ http://arxiv.org/pdf/1502.01852v1.pdf) (also know as the "MSRA + initialization"), use (Default):
`factor=2.0 mode='FAN_IN' uniform=False` * To get [Convolutional Architecture for Fast Feature Embedding]( http://arxiv.org/abs/1408.5093), use:
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index c429d53cdc9101486359a09d985a5649c649f3e2..0d25a09852544a7eb1ed5eb9c2f3402d9064d91a 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -198,23 +198,23 @@ def avg_pool3d(inputs, return utils.collect_named_outputs(outputs_collections, sc, outputs) -def _fused_batch_norm( - inputs, - decay=0.999, - center=True, - scale=False, - epsilon=0.001, - activation_fn=None, - param_initializers=None, - updates_collections=ops.GraphKeys.UPDATE_OPS, - is_training=True, - reuse=None, - variables_collections=None, - outputs_collections=None, - trainable=True, - data_format=DATA_FORMAT_NHWC, - zero_debias_moving_mean=False, - scope=None): +def _fused_batch_norm(inputs, + decay=0.999, + center=True, + scale=False, + epsilon=0.001, + activation_fn=None, + param_initializers=None, + param_regularizers=None, + updates_collections=ops.GraphKeys.UPDATE_OPS, + is_training=True, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + data_format=DATA_FORMAT_NHWC, + zero_debias_moving_mean=False, + scope=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing @@ -257,6 +257,7 @@ def _fused_batch_norm( maintain a linear activation. param_initializers: Optional initializers for beta, gamma, moving mean and moving variance. + param_regularizers: Optional regularizer for beta and gamma. updates_collections: Collections to collect the update ops for computation. The updates_ops need to be executed with the train_op. If None, a control dependency would be added to make sure the updates are @@ -285,7 +286,6 @@ def _fused_batch_norm( ValueError: If the rank of `inputs` is neither 2 or 4. ValueError: If rank or `C` dimension of `inputs` is undefined. """ - # TODO(reedwm): Add support for fp16 inputs. if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') with variable_scope.variable_scope( @@ -309,7 +309,6 @@ def _fused_batch_norm( new_shape = [-1, channels, 1, 1] inputs = array_ops.reshape(inputs, new_shape) inputs_shape = inputs.get_shape() - dtype = inputs.dtype.base_dtype if data_format == DATA_FORMAT_NHWC: params_shape = inputs_shape[-1:] else: @@ -319,23 +318,30 @@ def _fused_batch_norm( (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. - trainable_beta = trainable and center beta_collections = utils.get_variable_collections(variables_collections, 'beta') + # Float32 required to avoid precision-loss when using fp16 input/output + variable_dtype = dtypes.float32 if not param_initializers: param_initializers = {} + if not param_regularizers: + param_regularizers = {} + beta_regularizer = param_regularizers.get('beta') + gamma_regularizer = param_regularizers.get('gamma') + if center: beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) beta = variables.model_variable( 'beta', shape=params_shape, - dtype=dtype, + dtype=variable_dtype, initializer=beta_initializer, + regularizer=beta_regularizer, collections=beta_collections, - trainable=trainable_beta) + trainable=trainable) else: - beta = array_ops.constant(0.0, shape=params_shape) + beta = array_ops.constant(0.0, dtype=variable_dtype, shape=params_shape) if scale: gamma_collections = utils.get_variable_collections( @@ -345,12 +351,13 @@ def _fused_batch_norm( gamma = variables.model_variable( 'gamma', shape=params_shape, - dtype=dtype, + dtype=variable_dtype, initializer=gamma_initializer, + regularizer=gamma_regularizer, collections=gamma_collections, trainable=trainable) else: - gamma = array_ops.constant(1.0, shape=params_shape) + gamma = array_ops.constant(1.0, dtype=variable_dtype, shape=params_shape) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. We disable variable partitioning while creating @@ -367,7 +374,7 @@ def _fused_batch_norm( moving_mean = variables.model_variable( 'moving_mean', shape=params_shape, - dtype=dtype, + dtype=variable_dtype, initializer=moving_mean_initializer, trainable=False, collections=moving_mean_collections) @@ -378,7 +385,7 @@ def _fused_batch_norm( moving_variance = variables.model_variable( 'moving_variance', shape=params_shape, - dtype=dtype, + dtype=variable_dtype, initializer=moving_variance_initializer, trainable=False, collections=moving_variance_collections) @@ -596,6 +603,7 @@ def batch_norm(inputs, epsilon=epsilon, activation_fn=activation_fn, param_initializers=param_initializers, + param_regularizers=param_regularizers, updates_collections=updates_collections, is_training=is_training, reuse=reuse, @@ -1394,7 +1402,8 @@ def dropout(inputs, noise_shape=None, is_training=True, outputs_collections=None, - scope=None): + scope=None, + seed=None): """Returns a dropout op applied to the input. With probability `keep_prob`, outputs the input element scaled up by @@ -1412,6 +1421,8 @@ def dropout(inputs, Otherwise, inputs is returned. outputs_collections: Collection to add the outputs. scope: Optional scope for name_scope. + seed: A Python integer. Used to create random seeds. See + @{tf.set_random_seed} for behavior. Returns: A tensor representing the output of the operation. @@ -1421,6 +1432,7 @@ def dropout(inputs, inputs = ops.convert_to_tensor(inputs) layer = core_layers.Dropout(rate=1 - keep_prob, noise_shape=noise_shape, + seed=seed, name=sc.name, _scope=sc) outputs = layer.apply(inputs, training=is_training) @@ -2008,7 +2020,7 @@ def layer_norm(inputs, Given a tensor `inputs` of rank `R`, moments are calculated and normalization is performed over axes `begin_norm_axis ... R - 1`. Scaling and centering, - if requested, is performed over axes `begin_shift_axis .. R - 1`. + if requested, is performed over axes `begin_params_axis .. R - 1`. By default, `begin_norm_axis = 1` and `begin_params_axis = -1`, meaning that normalization is performed over all but the first axis @@ -2549,7 +2561,10 @@ def separable_convolution2d( regularizer=weights_regularizer, trainable=trainable, collections=weights_collections) - strides = [1, stride_h, stride_w, 1] + strides = [1, 1, stride_h, + stride_w] if data_format.startswith('NC') else [ + 1, stride_h, stride_w, 1 + ] outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding, rate=utils.two_element_tuple(rate), @@ -2639,51 +2654,52 @@ def spatial_softmax(features, ValueError: If unexpected data_format specified. ValueError: If num_channels dimension is unspecified. """ - shape = array_ops.shape(features) - static_shape = features.shape - if data_format == DATA_FORMAT_NHWC: - height, width, num_channels = shape[1], shape[2], static_shape[3] - elif data_format == DATA_FORMAT_NCHW: - num_channels, height, width = static_shape[1], shape[2], shape[3] - else: - raise ValueError('data_format has to be either NCHW or NHWC.') - if num_channels.value is None: - raise ValueError('The num_channels dimension of the inputs to ' - '`spatial_softmax` should be defined. Found `None`.') - - with ops.name_scope(name, 'spatial_softmax', [features]) as name: - # Create tensors for x and y coordinate values, scaled to range [-1, 1]. - pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), - math_ops.lin_space(-1., 1., num=width), - indexing='ij') - pos_x = array_ops.reshape(pos_x, [height * width]) - pos_y = array_ops.reshape(pos_y, [height * width]) - if temperature is None: - temperature_collections = utils.get_variable_collections( - variables_collections, 'temperature') - temperature = variables.model_variable( - 'temperature', - shape=(), - dtype=dtypes.float32, - initializer=init_ops.ones_initializer(), - collections=temperature_collections, - trainable=trainable) - if data_format == 'NCHW': - features = array_ops.reshape(features, [-1, height * width]) + with variable_scope.variable_scope(name, 'spatial_softmax'): + shape = array_ops.shape(features) + static_shape = features.shape + if data_format == DATA_FORMAT_NHWC: + height, width, num_channels = shape[1], shape[2], static_shape[3] + elif data_format == DATA_FORMAT_NCHW: + num_channels, height, width = static_shape[1], shape[2], shape[3] else: - features = array_ops.reshape( - array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) - - softmax_attention = nn.softmax(features/temperature) - expected_x = math_ops.reduce_sum( - pos_x * softmax_attention, [1], keep_dims=True) - expected_y = math_ops.reduce_sum( - pos_y * softmax_attention, [1], keep_dims=True) - expected_xy = array_ops.concat([expected_x, expected_y], 1) - feature_keypoints = array_ops.reshape( - expected_xy, [-1, num_channels.value * 2]) - feature_keypoints.set_shape([None, num_channels.value * 2]) - return feature_keypoints + raise ValueError('data_format has to be either NCHW or NHWC.') + if num_channels.value is None: + raise ValueError('The num_channels dimension of the inputs to ' + '`spatial_softmax` should be defined. Found `None`.') + + with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): + # Create tensors for x and y coordinate values, scaled to range [-1, 1]. + pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), + math_ops.lin_space(-1., 1., num=width), + indexing='ij') + pos_x = array_ops.reshape(pos_x, [height * width]) + pos_y = array_ops.reshape(pos_y, [height * width]) + if temperature is None: + temperature_collections = utils.get_variable_collections( + variables_collections, 'temperature') + temperature = variables.model_variable( + 'temperature', + shape=(), + dtype=dtypes.float32, + initializer=init_ops.ones_initializer(), + collections=temperature_collections, + trainable=trainable) + if data_format == 'NCHW': + features = array_ops.reshape(features, [-1, height * width]) + else: + features = array_ops.reshape( + array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) + + softmax_attention = nn.softmax(features/temperature) + expected_x = math_ops.reduce_sum( + pos_x * softmax_attention, [1], keep_dims=True) + expected_y = math_ops.reduce_sum( + pos_y * softmax_attention, [1], keep_dims=True) + expected_xy = array_ops.concat([expected_x, expected_y], 1) + feature_keypoints = array_ops.reshape( + expected_xy, [-1, num_channels.value * 2]) + feature_keypoints.set_shape([None, num_channels.value * 2]) + return feature_keypoints def stack(inputs, layer, stack_args, **kwargs): diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 7c77e905f7432db4e42e7fda70aa72f32f40bb09..ae64b75d939ce0ffab300b01d3cfcb67a9d0da1c 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1345,11 +1345,20 @@ class DropoutTest(test.TestCase): num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) output = _layers.dropout(images) num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) - sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertLess(num_elem, num_elem_initial / 2 + 0.1) self.assertGreater(num_elem, num_elem_initial / 2 - 0.1) + def testDropoutSeed(self): + """Test that providing the same seed produces the same result.""" + height, width = 10, 10 + with self.test_session() as sess: + images = random_ops.random_uniform( + (5, height, width, 3), seed=1, name='images') + output1 = _layers.dropout(images, seed=1) + output2 = _layers.dropout(images, seed=1) + self.assertAllEqual(*sess.run([output1, output2])) + def testCreateDropoutNoTraining(self): height, width = 3, 3 with self.test_session() as sess: @@ -1358,7 +1367,6 @@ class DropoutTest(test.TestCase): num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0)) output = _layers.dropout(images, is_training=False) num_elem = math_ops.reduce_mean(math_ops.to_float(output > 0)) - sess.run(variables_lib.global_variables_initializer()) num_elem, num_elem_initial = sess.run([num_elem, num_elem_initial]) self.assertEqual(num_elem, num_elem_initial) outputs, inputs = sess.run([output, images]) @@ -1766,10 +1774,13 @@ class BatchNormTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'undefined'): _layers.batch_norm(inputs, data_format='NCHW') - def _testCreateOp(self, fused): + def _testCreateOp(self, fused, dtype=None): + if dtype is None: + dtype = dtypes.float32 height, width = 3, 3 with self.test_session(): - images = np.random.uniform(size=(5, height, width, 3)).astype('f') + images = np.random.uniform(size=(5, height, width, 3)).astype( + dtype.as_numpy_dtype) output = _layers.batch_norm(images, fused=fused) expected_name = ('BatchNorm/FusedBatchNorm' if fused else 'BatchNorm/batchnorm') @@ -1784,29 +1795,44 @@ class BatchNormTest(test.TestCase): def testCreateOpFused(self): self._testCreateOp(True) - def testCreateOpBetaRegularizer(self): + def testCreateOpFusedFloat16(self): + self._testCreateOp(True, dtypes.float16) + + def _testCreateOpBetaRegularizer(self, fused=True): height, width = 3, 3 with self.test_session(): reg = lambda x: 0.1 * math_ops.reduce_sum(x) images = np.random.uniform(size=(5, height, width, 3)).astype('f') - _layers.batch_norm(images, param_regularizers={'beta': reg}) + _layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused) self.assertEqual( len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1) beta_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0] self.assertEqual(beta_decay.op.name, 'BatchNorm/beta/Regularizer/mul') - def testCreateOpGammaRegularizer(self): + def testCreateOpBetaRegularizerFused(self): + self._testCreateOpBetaRegularizer(fused=True) + + def testCreateOpBetaRegularizerNonFused(self): + self._testCreateOpBetaRegularizer(fused=False) + + def _testCreateOpGammaRegularizer(self, fused=True): height, width = 3, 3 with self.test_session(): reg = lambda x: 0.1 * math_ops.reduce_sum(x) images = np.random.uniform(size=(5, height, width, 3)).astype('f') _layers.batch_norm( - images, param_regularizers={'gamma': reg}, scale=True) + images, param_regularizers={'gamma': reg}, scale=True, fused=fused) self.assertEqual( len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1) gamma_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0] self.assertEqual(gamma_decay.op.name, 'BatchNorm/gamma/Regularizer/mul') + def testCreateOpGammaRegularizerFused(self): + self._testCreateOpGammaRegularizer(fused=True) + + def testCreateOpGammaRegularizerNonFused(self): + self._testCreateOpGammaRegularizer(fused=False) + def testCreateVariables(self): height, width = 3, 3 with self.test_session(): @@ -2639,10 +2665,63 @@ class BatchNormTest(test.TestCase): def testBatchNormBeta(self): # Test case for 11673 with self.test_session() as sess: - a = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10)) - b = _layers.batch_norm(a, center=False, data_format='NCHW', - zero_debias_moving_mean=True) + a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10)) + _layers.batch_norm( + a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True) + a_16 = array_ops.placeholder(dtypes.float16, shape=(10, 10, 10, 10)) + _layers.batch_norm( + a_16, center=False, data_format='NCHW', zero_debias_moving_mean=True) + sess.run(variables_lib.global_variables_initializer()) + + def testVariablesAreFloat32(self): + height, width = 3, 3 + with self.test_session(): + images = random_ops.random_uniform( + (5, height, width, 3), seed=1, dtype=dtypes.float16) + _layers.batch_norm(images, scale=True) + beta = variables.get_variables_by_name('beta')[0] + gamma = variables.get_variables_by_name('gamma')[0] + self.assertEqual(beta.dtype, dtypes.float32_ref) + self.assertEqual(gamma.dtype, dtypes.float32_ref) + moving_mean = variables.get_variables_by_name('moving_mean')[0] + moving_variance = variables.get_variables_by_name('moving_variance')[0] + self.assertEqual(moving_mean.dtype, dtypes.float32_ref) + self.assertEqual(moving_variance.dtype, dtypes.float32_ref) + + def _runFusedBatchNorm(self, shape, dtype): + channels = shape[1] + images = np.arange(np.product(shape), dtype=dtype).reshape(shape) + beta = init_ops.constant_initializer( + np.arange(2, channels + 2, dtype=np.float32)) + gamma = init_ops.constant_initializer( + np.arange(10, channels + 10, dtype=np.float32) * 2.0) + mean = init_ops.constant_initializer( + np.arange(3, channels + 3, dtype=np.float32) * 5.0) + variance = init_ops.constant_initializer( + np.arange(1, channels + 1, dtype=np.float32) * 4.0) + output = _layers.batch_norm( + images, + fused=True, + is_training=True, + scale=True, + epsilon=0.5, + param_initializers={ + 'beta': beta, + 'gamma': gamma, + 'moving_mean': mean, + 'moving_variance': variance, + }, + data_format='NCHW') + with self.test_session(use_gpu=True) as sess: sess.run(variables_lib.global_variables_initializer()) + return sess.run(output) + + def testFusedBatchNormFloat16MatchesFloat32(self): + if test.is_gpu_available(cuda_only=True): + shape = [5, 4, 2, 3] + res_32 = self._runFusedBatchNorm(shape, np.float32) + res_16 = self._runFusedBatchNorm(shape, np.float16) + self.assertAllClose(res_32, res_16, rtol=1e-3) def testAdjustmentCreated(self): # Tests that the adjustment is appropriately passed to and used by the core @@ -3247,16 +3326,24 @@ class SeparableConv2dTest(test.TestCase): for model_variable in model_variables: self.assertEqual(trainable, model_variable in trainable_variables) - def testConvNCHW(self): - for num_filters, correct_output_filters in [(None, 6), (8, 8)]: + def testSepConvNCHW(self): + for num_filters, correct_output_filters in zip((None, 5), (6, 5)): with self.test_session(): - batch, height, width = 4, 5, 6 + batch, height, width = 4, 10, 12 + kernel_dim, stride = 3, 2 images = random_ops.random_uniform((batch, 3, height, width), seed=1) output = layers_lib.separable_conv2d( - images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW') - self.assertListEqual( - output.get_shape().as_list(), [batch, correct_output_filters, - height - 2, width - 2]) + images, + num_outputs=num_filters, + kernel_size=[kernel_dim, kernel_dim], + depth_multiplier=2, + stride=stride, + padding='VALID', + data_format='NCHW') + self.assertListEqual(output.get_shape().as_list(), [ + batch, correct_output_filters, (height - kernel_dim + 1) // stride, + (width - kernel_dim + 1) // stride + ]) class ScaleGradientTests(test.TestCase): diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..123275e1fde047cd3772528641b2e3b09742fbdc --- /dev/null +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -0,0 +1,583 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reversible Residual Block. + +From +[The Reversible Residual Network: Backpropagation Without Storing +Activations](https://arxiv.org/abs/1707.04585). + +Also contains the @recompute_grad decorator, which recomputes the forward +function on the backwards pass. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import re + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.contrib.framework.python import ops as contrib_framework_ops +from tensorflow.python.framework import function +from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.layers import base +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + +__all__ = ["rev_block", "RevBlock", "recompute_grad"] + +LAYER_RE = re.compile(".*revlayer_([0-9]*)/([fg])/.*") + + +def _acc_grads(*lists_of_grads): + """Accumulates lists of gradients.""" + acc_grads = [] + for grads in zip(*lists_of_grads): + grads = [g for g in grads if g is not None] + if grads: + acc_grads.append(math_ops.add_n(grads)) + else: + acc_grads.append(None) + return acc_grads + + +def _rev_layer_forward(xs, f, g, f_side_input, g_side_input, + gate_outputs=False): + """Forward for 1 reversible layer.""" + x1, x2 = xs + y1 = x1 + (f(x2, f_side_input) if f_side_input else f(x2)) + y2 = x2 + (g(y1, g_side_input) if g_side_input else g(y1)) + if gate_outputs: + return control_flow_ops.tuple([y1, y2]) + else: + return (y1, y2) + + +def _rev_layer_backward(ys, grad_ys, f, g, f_vars, f_side_input, g_vars, + g_side_input): + """Backprop for 1 layer.""" + y1, y2 = ys + grad_y1, grad_y2 = grad_ys + + # Reconstruct intermediates and inputs (x1, x2) + # stop_gradients required on fn inputs to prevent infinite recursion into this + # grad function on the calls to gradients. + y1_stop = array_ops.stop_gradient(y1) + g_side_input = [array_ops.stop_gradient(t) for t in g_side_input] + gy1 = g(y1_stop, g_side_input) if g_side_input else g(y1_stop) + + x2 = y2 - gy1 + x2_stop = array_ops.stop_gradient(x2) + f_side_input = [array_ops.stop_gradient(t) for t in f_side_input] + fx2 = f(x2_stop, f_side_input) if f_side_input else f(x2_stop) + + x1 = y1 - fx2 + + # Compute gradients wrt to inputs + # dL/dy2 * dG(y1)/y1 + grad_gy1_y2 = gradients_impl.gradients(gy1, y1_stop, grad_y2)[0] + grad_x1 = grad_y1 + grad_gy1_y2 + grad_x2 = ( + gradients_impl.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 + + gradients_impl.gradients(fx2, x2_stop, grad_gy1_y2)[0]) + + # Compute gradients wrt to vars and side inputs in f and g + grads1 = gradients_impl.gradients(gy1, g_vars + g_side_input, grad_y2) + grad_g_vars, grad_g_side = grads1[:len(g_vars)], grads1[len(g_vars):] + grads2 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_y1) + grad_f_y1, grad_f_side1 = grads2[:len(f_vars)], grads2[len(f_vars):] + grads3 = gradients_impl.gradients(fx2, f_vars + f_side_input, grad_gy1_y2) + grad_f_y2, grad_f_side2 = grads3[:len(f_vars)], grads3[len(f_vars):] + grad_f_vars = _acc_grads(grad_f_y1, grad_f_y2) + + grad_f_side = _acc_grads(grad_f_side1, grad_f_side2) + + # Put returns in a tuple to ensure a constant memory budget (i.e. don't want + # the subsequent layer to start computing and consuming memory based on a + # subset of these values). + outputs = ((x1, x2), (grad_x1, grad_x2), (grad_f_vars, grad_f_side), + (grad_g_vars, grad_g_side)) + tupled = control_flow_ops.tuple(nest.flatten(outputs)) + return nest.pack_sequence_as(outputs, tupled) + + +def _rev_block_forward(x1, + x2, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + gate_outputs=False): + """Forward for a series of reversible layers.""" + out = (x1, x2) + for i in xrange(num_layers): + out = _rev_layer_forward( + out, f[i], g[i], f_side_input, g_side_input, gate_outputs=gate_outputs) + + y1, y2 = out + return y1, y2 + + +def _scope_wrap(fn, scope): + + @functools.wraps(fn) + def wrap(*args, **kwargs): + with variable_scope.variable_scope(scope): + return fn(*args, **kwargs) + + return wrap + + +class RevBlock(base.Layer): + """Block of reversible layers. See rev_block.""" + + def __init__(self, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + use_efficient_backprop=True, + name="revblock", + **kwargs): + super(RevBlock, self).__init__(name=name, **kwargs) + + if isinstance(f, list): + assert len(f) == num_layers + else: + f = [f] * num_layers + + if isinstance(g, list): + assert len(g) == num_layers + else: + g = [g] * num_layers + + f = [_scope_wrap(fn, "revlayer_%d/f" % i) for i, fn in enumerate(f)] + g = [_scope_wrap(fn, "revlayer_%d/g" % i) for i, fn in enumerate(g)] + + self.f = f + self.g = g + + self.num_layers = num_layers + self.f_side_input = f_side_input or [] + self.g_side_input = g_side_input or [] + + self._use_efficient_backprop = use_efficient_backprop + + def call(self, inputs, forward=True): + vs = variable_scope.get_variable_scope() + vars_before = vs.global_variables() + + if forward: + x1, x2 = inputs + out = self._forward(x1, x2) + else: + y1, y2 = inputs + out = self._backward(y1, y2) + + # Add any created variables to the Layer's variable stores + new_vars = vs.global_variables()[len(vars_before):] + train_vars = vs.trainable_variables() + for new_var in new_vars: + if new_var in train_vars: + self._trainable_weights.append(new_var) + else: + self._non_trainable_weights.append(new_var) + + return out + + def forward(self, x1, x2): + return self.apply([x1, x2]) + + def backward(self, y1, y2): + return self.apply([y1, y2], forward=False) + + def build(self, _): + logging.warn("RevBlock constructs its variables on first call, not on " + "build.") + self.built = True + + def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): + """Custom gradient fn for a block of reversible residual layers.""" + side_inputs = inputs[2:] + f_side_idxs = [None] * len(self.f_side_input) + g_side_idxs = [None] * len(self.g_side_input) + assert len(side_inputs) == len(self.f_side_input) + len(self.g_side_input) + + for i, t in enumerate(side_inputs): + if t in self.f_side_input: + f_side_idxs[self.f_side_input.index(t)] = i + elif t in self.g_side_input: + g_side_idxs[self.g_side_input.index(t)] = i + else: + assert False + + f_vars = [[] for _ in range(self.num_layers)] + g_vars = [[] for _ in range(self.num_layers)] + f_vars_idxs = [[] for _ in range(self.num_layers)] + g_vars_idxs = [[] for _ in range(self.num_layers)] + + for i, t in enumerate(variables): + ref = _underlying_variable_ref(t) + + # Use the name to identify the layer number and function (f or g) + regex = LAYER_RE.match(ref.name) + layer_no = int(regex.group(1)) + fn_name = regex.group(2) + if fn_name == "f": + f_vars[layer_no].append(ref) + f_vars_idxs[layer_no].append(i) + else: + assert fn_name == "g" + g_vars[layer_no].append(ref) + g_vars_idxs[layer_no].append(i) + + f_var_grads = [] + g_var_grads = [] + f_side_grads = [] + g_side_grads = [] + + # Reverse variable containers to go backward + f_vars.reverse() + g_vars.reverse() + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() + + with variable_scope.variable_scope(self.scope_name, reuse=True): + for i in xrange(self.num_layers): + ys, grad_ys, f_ret, g_ret = _rev_layer_backward( + ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], + self.g_side_input) + + grad_f_vars, grad_f_side = f_ret + grad_g_vars, grad_g_side = g_ret + f_var_grads.append(grad_f_vars) + g_var_grads.append(grad_g_vars) + f_side_grads.append(grad_f_side) + g_side_grads.append(grad_g_side) + + # Accumulate layer gradients for f_side_input and g_side_input + acc_f_side_grads = _acc_grads(*f_side_grads) + acc_g_side_grads = _acc_grads(*g_side_grads) + + # Use the stored idxs to put gradients in the passed-in order. + side_input_grads = [None] * len(side_inputs) + variable_grads = [None] * len(variables) + + # Variable gradients were collected in reverse layer order. Reverse to match + # idxs. + f_var_grads.reverse() + g_var_grads.reverse() + for idxs, grads in list(zip(f_vars_idxs, f_var_grads)) + list( + zip(g_vars_idxs, g_var_grads)): + for i, grad in zip(idxs, grads): + variable_grads[i] = grad + + for i, grad in zip(f_side_idxs, acc_f_side_grads): + side_input_grads[i] = grad + for i, grad in zip(g_side_idxs, acc_g_side_grads): + side_input_grads[i] = grad + + grad_x1, grad_x2 = grad_ys + return [grad_x1, grad_x2] + side_input_grads, variable_grads + + def _forward(self, x1, x2): + """Run forward through the reversible layers.""" + + side_inputs = [self.f_side_input, self.g_side_input] + flat_side_inputs = nest.flatten(side_inputs) + + custom_grad_fn = ( + self._efficient_grad_fn if self._use_efficient_backprop else None) + + @_fn_with_custom_grad(custom_grad_fn) + def _forward_wrap(x1_, x2_, *flat_side_inputs): + f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs) + return _rev_block_forward( + x1_, + x2_, + self.f, + self.g, + num_layers=self.num_layers, + f_side_input=f_side, + g_side_input=g_side, + gate_outputs=self._use_efficient_backprop) + + return _forward_wrap(x1, x2, *flat_side_inputs) + + def _backward(self, y1, y2): + """Run backward through the reversible layers.""" + + f = list(self.f) + g = list(self.g) + f.reverse() + g.reverse() + + for i in xrange(self.num_layers): + gy1 = g[i](y1, self.g_side_input) if self.g_side_input else g[i](y1) + x2 = y2 - gy1 + fx2 = f[i](x2, self.f_side_input) if self.f_side_input else f[i](x2) + x1 = y1 - fx2 + + y1, y2 = x1, x2 + + return x1, x2 + + +def rev_block(x1, + x2, + f, + g, + num_layers=1, + f_side_input=None, + g_side_input=None, + is_training=True): + """A block of reversible residual layers. + + A reversible residual layer is defined as: + + ``` + y1 = x1 + f(x2, f_side_input) + y2 = x2 + g(y1, g_side_input) + ``` + + A reversible residual block, defined here, is a series of reversible residual + layers. + + Limitations: + * f and g must not close over any Tensors; all side inputs to f and g should + be passed in with f_side_input and g_side_input which will be forwarded to + f and g. + * f and g must not change the dimensionality of their inputs in order for the + addition in the equations above to work. + + Args: + x1: a float Tensor. + x2: a float Tensor. + f: a function, (Tensor) -> (Tensor) (or list of such of length num_layers). + Should not change the shape of the Tensor. Can make calls to get_variable. + See f_side_input if there are side inputs. + g: a function, (Tensor) -> (Tensor) (or list of such of length num_layers). + Should not change the shape of the Tensor. Can make calls to get_variable. + See g_side_input if there are side inputs. + num_layers: int, number of reversible residual layers. Each layer will + apply f and g according to the equations above, with new variables in each + layer. + f_side_input: list of Tensors, side input to f. If not None, signature of f + should be (Tensor, list) -> (Tensor). + g_side_input: list of Tensors, side input to g. If not None, signature of g + should be (Tensor, list) -> (Tensor). + is_training: bool, whether to actually use the efficient backprop codepath. + + Returns: + y1, y2: tuple of float Tensors. + """ + block = RevBlock( + f=f, + g=g, + num_layers=num_layers, + f_side_input=f_side_input, + g_side_input=g_side_input, + use_efficient_backprop=is_training, + _reuse=variable_scope.get_variable_scope().reuse) + return block.forward(x1, x2) + + +def recompute_grad(fn): + """Decorator that recomputes the function on the backwards pass. + + Args: + fn: a function that takes Tensors (all as positional arguments) and returns + a tuple of Tensors. + + Returns: + A wrapped fn that is identical to fn when called, but its activations will + be discarded and recomputed on the backwards pass (i.e. on a call to + tf.gradients). + """ + + @functools.wraps(fn) + def wrapped(*args): + return _recompute_grad(fn, args) + + return wrapped + + +def _recompute_grad(fn, args): + """See recompute_grad.""" + + cached_vs = [] + cached_arg_scope = [] + + def grad_fn(inputs, variables, outputs, output_grads): + """Recompute outputs for gradient computation.""" + del outputs + # Recompute outputs + with framework_ops.control_dependencies(output_grads): + with contrib_framework_ops.arg_scope(cached_arg_scope[0]): + with variable_scope.variable_scope(cached_vs[0], reuse=True): + outputs = fn(*inputs) + + if not (isinstance(outputs, list) or isinstance(outputs, tuple)): + outputs = [outputs] + outputs = list(outputs) + grads = gradients_impl.gradients(outputs, inputs + variables, output_grads) + grad_inputs = grads[:len(inputs)] + grad_vars = grads[len(inputs):] + return grad_inputs, grad_vars + + @_fn_with_custom_grad(grad_fn) + def fn_with_recompute(*args): + cached_vs.append(variable_scope.get_variable_scope()) + # TODO(rsepassi): Rm conditional in TF 1.4 + if hasattr(contrib_framework_ops, "current_arg_scope"): + cached_arg_scope.append(contrib_framework_ops.current_arg_scope()) + else: + cached_arg_scope.append({}) + return fn(*args) + + return fn_with_recompute(*args) + + +def _underlying_variable_ref(t): + """Find the underlying variable ref. + + Traverses through Identity, ReadVariableOp, and Enter ops. + Stops when op type has Variable or VarHandle in name. + + Args: + t: a Tensor + + Returns: + a Tensor that is a variable ref, or None on error. + """ + while t.op.type in ["Identity", "ReadVariableOp", "Enter"]: + t = t.op.inputs[0] + + op_type = t.op.type + if "Variable" in op_type or "VarHandle" in op_type: + return t + else: + return None + + +def _fn_with_custom_grad(grad_fn, use_global_vars=False): + """Decorator to create a subgraph with a custom gradient function. + + The subgraph created by the decorated function is NOT put in a Defun and so + does not suffer from the limitations of the Defun (all subgraph ops on the + same device, no summaries). + + Args: + grad_fn: function with signature + (inputs, variables, outputs, output_grads) -> (grad_inputs, grad_vars), + all of which are lists of Tensors. + use_global_vars: if True, variables will be the global variables created. + If False, will be the trainable variables. + + Returns: + Decorator for function such that the gradient is defined by grad_fn. + """ + + def dec(fn): + + @functools.wraps(fn) + def wrapped(*args): + return _fn_with_custom_grad_internal( + fn, args, grad_fn, use_global_vars=use_global_vars) + + return wrapped + + return dec + + +def _fn_with_custom_grad_internal(fn, inputs, grad_fn, use_global_vars=False): + """Create a subgraph with a custom gradient. + + Args: + fn: function that takes inputs as arguments and produces 1 or more Tensors. + inputs: list, will be passed as fn(*inputs). + grad_fn: function with signature + (inputs, vars, outputs, output_grads) -> (grad_inputs, grad_vars), + all of which are lists of Tensors. + use_global_vars: if True, variables will be the global variables created. + If False, will be the trainable variables. + + Returns: + fn(*inputs) + """ + vs = variable_scope.get_variable_scope() + get_vars_fn = ( + vs.global_variables if use_global_vars else vs.trainable_variables) + len_before_vars = len(get_vars_fn()) + inputs = list(inputs) + outputs = fn(*inputs) + train_vars = get_vars_fn()[len_before_vars:] + + if grad_fn is None: + return outputs + + if not (isinstance(outputs, tuple) or isinstance(outputs, list)): + outputs = [outputs] + outputs = list(outputs) + + defun_inputs = [inputs, train_vars, outputs] + + def custom_grad_fn(op, *dys): + """Custom grad fn applying grad_fn for identity Defun.""" + fn_inputs, fn_vars, fn_outputs = nest.pack_sequence_as( + defun_inputs, list(op.inputs)) + dys = list(dys) + assert len(fn_outputs) == len(outputs) + assert len(fn_outputs) == len(dys) + + grad_inputs, grad_vars = grad_fn(fn_inputs, fn_vars, fn_outputs, dys) + grad_outputs = [None] * len(fn_outputs) + return tuple(grad_inputs + grad_vars + grad_outputs) + + # The Defun takes as input the original inputs, the trainable variables + # created in fn, and the outputs. In the forward it passes through the + # outputs. In the backwards, it produces gradients for the original inputs + # and the trainable variables. + in_types = [t.dtype for t in inputs] + out_types = [t.dtype for t in outputs] + var_types = [t.dtype for t in train_vars] + + # Get a unique name for the Defun + with framework_ops.name_scope("identity_custom_grad") as ns: + defun_name = ns + + @function.Defun( + *(in_types + var_types + out_types), + func_name=defun_name, + python_grad_func=custom_grad_fn, + shape_func=lambda _: [t.get_shape() for t in outputs]) + def identity(*args): + _, _, outs = nest.pack_sequence_as(defun_inputs, args) + return tuple([array_ops.identity(t) for t in outs]) + + flat_inputs = nest.flatten(defun_inputs) + id_out = identity(*flat_inputs) + return id_out diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cbcbcd75114a522b95631e4e7e95c1641b0a9987 --- /dev/null +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -0,0 +1,364 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for RevBlock.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.layers.python.layers import rev_block_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.layers import convolutional +from tensorflow.python.layers import core as core_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RevBlockTest(test.TestCase): + CHANNELS = 8 + NUM_LAYERS = 4 + BATCH_SIZE = 16 + + def testForwardBackward(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + block = rev_block_lib.RevBlock(f, g, num_layers=3) + y1, y2 = block.forward(x1, x2) + x1_inv, x2_inv = block.backward(y1, y2) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv]) + + self.assertAllClose(x1, x1_inv) + self.assertAllClose(x2, x2_inv) + + def testBackwardForward(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + y = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + y1, y2 = array_ops.split(y, 2, axis=-1) + + block = rev_block_lib.RevBlock(f, g, num_layers=3) + x1, x2 = block.backward(y1, y2) + y1_inv, y2_inv = block.forward(x1, x2) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv]) + + self.assertAllClose(y1, y1_inv) + self.assertAllClose(y2, y2_inv) + + def _testRevBlock(self, + x=None, + f=None, + g=None, + f_side_input=None, + g_side_input=None): + random_seed.set_random_seed(1234) + + if f is None: + + def f(x): # pylint: disable=function-redefined + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + if g is None: + + def g(x): # pylint: disable=function-redefined + return core_layers.dense(x, self.CHANNELS // 2, use_bias=True) + + if f_side_input is None: + f_side_input = [] + + if g_side_input is None: + g_side_input = [] + + if x is None: + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + with variable_scope.variable_scope("rev_test") as vs: + y1_rev, y2_rev = rev_block_lib.rev_block( + x1, + x2, + f, + g, + f_side_input=f_side_input, + g_side_input=g_side_input, + num_layers=self.NUM_LAYERS) + y_rev = array_ops.concat([y1_rev, y2_rev], axis=1) + fg_vars = vs.trainable_variables() + + num_vars = len(variables.global_variables()) + with variable_scope.variable_scope(vs, reuse=True): + y1, y2 = rev_block_lib.rev_block( + x1, + x2, + f, + g, + f_side_input=f_side_input, + g_side_input=g_side_input, + num_layers=self.NUM_LAYERS, + is_training=False) + y = array_ops.concat([y1, y2], axis=1) + # Ensure no new vars were created - full reuse + assert len(variables.global_variables()) == num_vars + + loss_rev = math_ops.reduce_mean(y_rev + 10.) + loss = math_ops.reduce_mean(y + 10.) + + wrt = [x] + f_side_input + g_side_input + fg_vars + grads_rev = gradients_impl.gradients(loss_rev, wrt) + grads = gradients_impl.gradients(loss, wrt) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads]) + self.assertAllClose(y_val, yd_val) + for g1, g2 in zip(gd_val, g_val): + self.assertAllClose(g1, g2) + + def testRevBlock(self): + self._testRevBlock() + + def testSideInput(self): + f_side_input = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS // 2]) + + def f(x, side_input): + return core_layers.dense( + x, self.CHANNELS // 2, use_bias=True) + side_input[0] + + self._testRevBlock(f=f, f_side_input=[f_side_input]) + + def testMultipleFns(self): + + def f1(x): + return core_layers.dense(x, self.CHANNELS // 2) + + def f2(x): + return core_layers.dense(x, self.CHANNELS // 2, activation=nn_ops.relu) + + self._testRevBlock(f=[f1, f2, f1, f2]) + + # TODO(rsepassi): Recent change to conv seems to have broken this test. Find + # out why. + def _testConvAndBatchNorm(self): + + x = random_ops.random_uniform( + [self.BATCH_SIZE, 10, self.CHANNELS], dtype=dtypes.float32) + + def f(x): + x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") + x = layers.batch_norm(x, is_training=True) + x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") + x = layers.batch_norm(x, is_training=True) + return x + + self._testRevBlock(x=x, f=f) + + def testReuse(self): + + def f(x): + return core_layers.dense(x, self.CHANNELS // 2) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2) + + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + with variable_scope.variable_scope("test"): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_before = len(variables.global_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + + loss = math_ops.reduce_mean(y1 + y2) + _ = gradients_impl.gradients(loss, + [x] + variables.trainable_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + + +class RecomputeTest(test.TestCase): + + def testRecompute(self): + + def layer(x, name=None): + with variable_scope.variable_scope(name, default_name="layer"): + x = layers.layer_norm(x) + x = convolutional.conv1d( + x, + 10, + 1, + use_bias=False, + kernel_initializer=init_ops.constant_initializer(42.42)) + x = nn_ops.relu(x) + return x + + def fn(x): + out = x + for _ in range(3): + out = layer(out) + return out + + @rev_block_lib.recompute_grad + def fn_recompute(x): + return fn(x) + + x = random_ops.random_uniform((3, 1, 3)) + recompute_vars = None + with variable_scope.variable_scope("recompute") as vs: + out1 = math_ops.reduce_sum(fn_recompute(x)) + recompute_vars = vs.trainable_variables() + reg_vars = None + with variable_scope.variable_scope("regular") as vs: + out2 = math_ops.reduce_sum(fn(x)) + reg_vars = vs.trainable_variables() + + grad1 = gradients_impl.gradients(out1, recompute_vars) + grad2 = gradients_impl.gradients(out2, reg_vars) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + outs = sess.run([out1, out2, grad1, grad2]) + self.assertAllClose(outs[0], outs[1]) + for g1, g2 in zip(outs[2], outs[3]): + self.assertAllClose(g1, g2) + + +class FnWithCustomGradTest(test.TestCase): + + def testCorrectness(self): + + w = random_ops.random_uniform([6, 10]) + + def fn(a, b, c): + return core_layers.dense( + a, + 10, + use_bias=False, + kernel_initializer=lambda shape, dtype, partition_info: w + ) + math_ops.matmul(b, c) + + def grad_fn(inputs, trainable_variables, outputs, grad_outputs): + outputs = outputs[0] + grad_outputs = grad_outputs[0] + grad_inputs = gradients_impl.gradients( + outputs, inputs, grad_ys=grad_outputs) + grad_vars = gradients_impl.gradients( + outputs, trainable_variables, grad_ys=grad_outputs) + return grad_inputs, grad_vars + + custom_fn = rev_block_lib._fn_with_custom_grad(grad_fn)(fn) + + a = random_ops.random_uniform([11, 6]) + b = random_ops.random_uniform([11, 7]) + c = random_ops.random_uniform([7, 10]) + + out = fn(a, b, c) + custom_out = custom_fn(a, b, c) + self.assertEqual(out.get_shape().as_list(), + custom_out.get_shape().as_list()) + + loss = math_ops.reduce_mean(out) + custom_loss = math_ops.reduce_mean(custom_out) + + grads = gradients_impl.gradients( + loss, [a, b, c] + [variables.trainable_variables()[0]]) + custom_grads = gradients_impl.gradients( + custom_loss, [a, b, c] + [variables.trainable_variables()[1]]) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + out_val, custom_out_val, grads_val, custom_grads_val = sess.run( + [out, custom_out, grads, custom_grads]) + self.assertAllClose(out_val, custom_out_val) + for g1, g2 in zip(grads_val, custom_grads_val): + self.assertAllClose(g1, g2) + + def testCustomGrad(self): + + def fn(a, b, c): + return core_layers.dense(a, 10, use_bias=False) + math_ops.matmul(b, c) + + def grad_fn(inputs, trainable_variables, unused_outputs, + unused_grad_outputs): + grad_inputs = [ + array_ops.ones_like(t) * (i + 1.) for i, t in enumerate(inputs) + ] + grad_vars = [ + array_ops.ones_like(t) * (i + len(inputs) + 1.) + for i, t in enumerate(trainable_variables) + ] + return grad_inputs, grad_vars + + a = random_ops.random_uniform([11, 6]) + b = random_ops.random_uniform([11, 7]) + c = random_ops.random_uniform([7, 10]) + w = random_ops.random_uniform([6, 10]) + out = rev_block_lib._fn_with_custom_grad(grad_fn)(fn)(a, b, c) + loss = math_ops.reduce_mean(out) + grads = gradients_impl.gradients( + loss, [a, b, c, variables.trainable_variables()[0]]) + expected_grads = [ + array_ops.ones_like(t) * (i + 1.) for i, t in enumerate([a, b, c, w]) + ] + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + g_val, eg_val = sess.run([grads, expected_grads]) + for g1, g2 in zip(g_val, eg_val): + self.assertAllClose(g1, g2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index ac615b120c16d5d9a7798874653f8f00f8fd15b4..33f509ec121af6484411ab898fda37179511b708 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -10,7 +10,7 @@ package(default_visibility = [ "//tensorflow:internal", ]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_library( name = "learn", @@ -22,6 +22,8 @@ py_library( exclude = ["python/learn/**/*_test.py"], ), srcs_version = "PY2AND3", + # This library should not depend on sklearn, even though some of the code + # refers to it. (The code handles the presence of sklearn conditionally.) deps = [ "//tensorflow/contrib/factorization:factorization_py", "//tensorflow/contrib/framework:framework_py", @@ -55,6 +57,7 @@ py_library( "//tensorflow/python:logging_ops", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", "//tensorflow/python:nn", "//tensorflow/python:parsing_ops", "//tensorflow/python:partitioned_variables", @@ -76,6 +79,7 @@ py_library( "//tensorflow/python:weights_broadcast_ops", "//tensorflow/python/estimator", "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:inputs", "//tensorflow/python/estimator:inputs_queues", @@ -85,6 +89,7 @@ py_library( "//tensorflow/python/estimator:run_config", "//tensorflow/python/feature_column", "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:signature_constants", @@ -131,6 +136,7 @@ py_test( "//tensorflow/contrib/learn/python/learn/datasets", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:string_ops", "//tensorflow/python:training", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -148,17 +154,17 @@ py_test( ], ) -py_test( +tf_py_test( name = "experiment_test", size = "medium", srcs = ["python/learn/experiment_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", @@ -198,6 +204,7 @@ py_test( "//tensorflow/contrib/training:training_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", + "//tensorflow/python/estimator:run_config", ], ) @@ -216,6 +223,7 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", @@ -278,6 +286,8 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:protos_all_py", + "//tensorflow/python:session", + "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", @@ -319,12 +329,12 @@ py_test( "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn/python/learn/datasets", - "//tensorflow/contrib/losses:losses_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", ], ) @@ -363,10 +373,10 @@ py_test( deps = [ ":learn", "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lookup_ops", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", "//tensorflow/python/ops/losses", @@ -430,7 +440,6 @@ py_test( "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:functional_ops", @@ -439,6 +448,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:rnn_cell", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -450,6 +460,7 @@ py_test( size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], srcs_version = "PY2AND3", + tags = ["noasan"], deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -575,10 +586,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":learn", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python/estimator:export_output", "//tensorflow/python/saved_model:signature_constants", @@ -631,9 +642,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":learn", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//third_party/py/numpy", ], ) @@ -704,12 +715,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "graph_io_test", size = "small", srcs = ["python/learn/learn_io/graph_io_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -721,9 +731,11 @@ py_test( "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variables", ], + grpc_enabled = True, ) py_test( @@ -770,11 +782,12 @@ py_test( "//tensorflow/contrib/session_bundle:exporter", "//tensorflow/contrib/session_bundle:manifest_proto_py_pb2", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", "//tensorflow/python:random_ops", + "//tensorflow/python:session", "//tensorflow/python:training", "//third_party/py/numpy", "@six_archive//:six", @@ -822,12 +835,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":learn", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow/python:dtypes", ], ) @@ -855,7 +865,6 @@ py_binary( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python", # TODO(b/34059704): remove when fixed "//tensorflow/python:platform", ], ) diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py index 14750961efa30128708430fac038498de0a42118..ef5e620e8f08cffa7c2b945089aa5d150baefefc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import composable_model @@ -55,7 +55,7 @@ def _base_model_fn(features, labels, mode, params): raise NotImplementedError def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() assert global_step train_step = model.get_train_step(loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index cb15ef23e95d27c737d8ae08065b804bafd39a07..c17b41c0f767e19d9c3635a8f60347a49b297cfb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -23,7 +23,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -189,7 +189,7 @@ def _dnn_model_fn(features, labels, mode, params, config=None): """Returns the op to optimize the loss.""" return optimizers.optimize_loss( loss=loss, - global_step=contrib_variables.get_global_step(), + global_step=training_util.get_global_step(), learning_rate=_LEARNING_RATE, optimizer=_get_optimizer(optimizer), gradient_multipliers=( diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 788d2d0b1a58fad16712c968593b40de0d3979f0..05ed8b3409e68ae54e5ef89b3a1592a6f285565b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -30,7 +30,6 @@ import six from google.protobuf import message from tensorflow.contrib import layers -from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables @@ -60,6 +59,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -1230,7 +1230,7 @@ class Estimator(BaseEstimator): if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( - metrics_lib.streaming_mean(model_fn_ops.loss)) + metrics_lib.mean(model_fn_ops.loss)) return model_fn_ops def _get_predict_ops(self, features): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py index 248c6c733ffca351c848ba07110ba89928634a23..9d7c1a099aa4be64ca0296fa5b870597dabec7b4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py @@ -23,7 +23,7 @@ import tempfile import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn import models @@ -114,7 +114,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -129,7 +129,7 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return prediction, loss, train_op @@ -139,7 +139,7 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -150,7 +150,7 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index be2b0cb3ca959323b4de095ca072278f028be301..2a13a84627df35a68a4f04b25ab26ceecad0db0d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -32,7 +32,7 @@ from google.protobuf import text_format from tensorflow.contrib import learn from tensorflow.contrib import lookup -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import experiment @@ -132,7 +132,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -147,7 +147,7 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return prediction, loss, train_op @@ -157,7 +157,7 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -168,7 +168,7 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction @@ -241,7 +241,7 @@ def _build_estimator_for_resource_export_test(): const = constant_op.constant(-1, dtype=dtypes.int64) table = lookup.MutableHashTable( dtypes.string, dtypes.int64, const, name='LookupTableModel') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL): key = constant_op.constant(['key']) value = constant_op.constant([42], dtype=dtypes.int64) @@ -306,7 +306,7 @@ def _model_fn_ops( mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1)) + train_op=training_util.get_global_step().assign_add(1)) def _make_input_fn(features, labels): @@ -389,7 +389,7 @@ class EstimatorModelFnTest(test.TestCase): self.assertEqual(expected_param, params) self.assertEqual(model_dir, expected_model_dir) return (constant_op.constant(0.), constant_op.constant(0.), - variables.get_global_step().assign_add(1)) + training_util.get_global_step().assign_add(1)) est = estimator.Estimator(model_fn=_argument_checker, params=expected_param, model_dir=expected_model_dir) @@ -400,7 +400,7 @@ class EstimatorModelFnTest(test.TestCase): def _invalid_model_fn(features, labels): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): loss = 100.0 - w return None, loss, None @@ -415,7 +415,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) predictions = loss @@ -434,7 +434,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) return None, loss, train_op @@ -464,7 +464,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(init_fn=_init_fn)) est = estimator.Estimator(model_fn=_model_fn_scaffold) @@ -483,7 +483,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant([[1.]]), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(saver=self.mock_saver)) def input_fn(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py index 1d89dfb55b10b032cab7dcf434d396404d4eb83b..8131e0fde6fea5501cacc4714f53ed8d867ca70f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py @@ -22,7 +22,7 @@ import random import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python.learn import datasets from tensorflow.contrib.learn.python.learn import metric_spec @@ -62,7 +62,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["transformed_x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -100,7 +100,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -139,7 +139,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator_with_fe_fn = estimator_lib.Estimator( diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 468d792a0dccf5cf046a41ed8e1600940a15ac37..bc0e6fc0091c9b5419ab526855b404eb4a927e97 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -119,7 +119,7 @@ class Head(object): update_op = tf.contrib.layers.optimize_loss(optimizer=sync, loss=model_fn_ops.loss, ...) hooks = [sync.make_session_run_hook(is_chief)] - ... upate train_op and hooks in ModelFnOps and return + ... update train_op and hooks in ModelFnOps and return ``` """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 992b804f59ecd88fedc2fba10d3079f93c4fe83d..8f9d6fc318a357853bdb8e3264f6691b410006b1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -28,7 +28,7 @@ import time import numpy as np from tensorflow.contrib.factorization.python.ops import clustering_ops -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps from tensorflow.python.framework import ops @@ -128,7 +128,7 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): random_seed=params.get('random_seed'), kmeans_plus_plus_num_retries=params.get( 'kmeans_plus_plus_num_retries')).training_graph() - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses, name=KMeansClustering.LOSS_OP_NAME) summary.scalar('loss/raw', loss) training_op = with_dependencies([training_op, incr_step], loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index f5445ad4e728dbd3904279573771de9454b5d17c..37aa8b339622415d082933cdf66d2472a4119b48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -26,7 +26,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -170,7 +170,7 @@ def _linear_model_fn(features, labels, mode, params, config=None): weight_collections=[parent_scope]) def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() my_vars = ops.get_collection(parent_scope) grads = gradients.gradients(loss, my_vars) if gradient_clip_norm: @@ -252,7 +252,7 @@ def sdca_model_fn(features, labels, mode, params): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step(columns_to_variables, weight_column_name, loss_type, features, diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py index 93c62f87e8495f299a8c456574c7b40534186304..656d68b76888d9319c0b9be481f9b0478ac4314c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor @@ -57,7 +57,7 @@ def _logistic_regression_model_fn(features, labels, mode): predictions = math_ops.sigmoid(logits) loss = losses.sigmoid_cross_entropy(labels, logits) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, training_util.get_global_step(), optimizer='Adagrad', learning_rate=0.1) return predictions, loss, train_op diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 8be9c72adf1602826fabc650f350b57f72c886be..44e6c7c52dac524a22e9099e33e2aef82f8fe7ba 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -23,7 +23,6 @@ import collections import six -from tensorflow.contrib import framework as contrib_framework from tensorflow.contrib.framework import get_graph_from_inputs from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import metric_key @@ -32,6 +31,7 @@ from tensorflow.python.estimator import model_fn as core_model_fn_lib from tensorflow.python.estimator.export import export_output as core_export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -156,11 +156,11 @@ class ModelFnOps( else: if isinstance(predictions, dict): predictions = { - k: contrib_framework.convert_to_tensor_or_sparse_tensor(v) + k: sparse_tensor.convert_to_tensor_or_sparse_tensor(v) for k, v in six.iteritems(predictions) } else: - predictions = contrib_framework.convert_to_tensor_or_sparse_tensor( + predictions = sparse_tensor.convert_to_tensor_or_sparse_tensor( predictions) # Validate eval_metric_ops diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 307db76afe20a7743df16d169270a6f319497eb6..fc4bd1f461d7bfbfcfb78201d527959055342f0a 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -140,7 +140,8 @@ class Experiment(object): delay_workers_by_global_step=False, export_strategies=None, train_steps_per_iteration=None, - checkpoint_and_export=False): + checkpoint_and_export=False, + saving_listeners=None): """Constructor for `Experiment`. Creates an Experiment instance. None of the functions passed to this @@ -200,6 +201,9 @@ class Experiment(object): `save_checkpoints_steps`. Also, this parameter leads to the creation of a default `CheckpointSaverHook` instead of a `ValidationMonitor`, so the provided `train_monitors` will need to be adjusted accordingly. + saving_listeners: list of `CheckpointSaverListener` objects. Used by + tf.estimator.Estimator for callbacks that run immediately before or + after checkpoint savings. Raises: ValueError: if `estimator` does not implement Estimator interface, @@ -221,6 +225,9 @@ class Experiment(object): raise ValueError( "`estimator` must implement `tf.contrib.learn.Trainable`" "or `tf.estimator.`Estimator`.") + if saving_listeners is not None: + raise ValueError("`saving_listeners` must be `None` with " + "`tf.contrib.learn.Estimator`.") if isinstance(estimator, tpu_estimator.TPUEstimator): logging.warn( @@ -242,6 +249,7 @@ class Experiment(object): self._eval_delay_secs = eval_delay_secs self._continuous_eval_throttle_secs = continuous_eval_throttle_secs self._checkpoint_and_export = checkpoint_and_export + self._saving_listeners = saving_listeners # Using 1 on a non-cached file system requires a lot of overhead to # read the checkpoint state file. This is particular bad on GCS, so # we use a different default. This is a temporary band-aid, to be @@ -362,9 +370,11 @@ class Experiment(object): logging.info("Waiting %d secs before starting training.", remaining) time.sleep(delay_secs) - return self._call_train(input_fn=self._train_input_fn, - max_steps=self._train_steps, - hooks=self._train_monitors + extra_hooks) + return self._call_train( + input_fn=self._train_input_fn, + max_steps=self._train_steps, + hooks=self._train_monitors + extra_hooks, + saving_listeners=self._saving_listeners) def evaluate(self, delay_secs=None, name=None): """Evaluate on the evaluation data. @@ -712,9 +722,11 @@ class Experiment(object): break logging.info("Training model for %s steps", train_steps_per_iteration) - self._call_train(input_fn=self._train_input_fn, - steps=train_steps_per_iteration, - hooks=self._train_monitors) + self._call_train( + input_fn=self._train_input_fn, + steps=train_steps_per_iteration, + hooks=self._train_monitors, + saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") eval_result = self._call_evaluate(input_fn=self._eval_input_fn, @@ -762,9 +774,11 @@ class Experiment(object): Returns: The result of the `evaluate` call to the `Estimator`. """ - self._call_train(input_fn=self._train_input_fn, - steps=1, - hooks=self._train_monitors) + self._call_train( + input_fn=self._train_input_fn, + steps=1, + hooks=self._train_monitors, + saving_listeners=self._saving_listeners) eval_result = self._call_evaluate(input_fn=self._eval_input_fn, steps=1, @@ -792,7 +806,8 @@ class Experiment(object): return server def _call_train(self, _sentinel=None, # pylint: disable=invalid-name, - input_fn=None, steps=None, hooks=None, max_steps=None): + input_fn=None, steps=None, hooks=None, max_steps=None, + saving_listeners=None): if _sentinel is not None: raise ValueError("_call_train should be called with keyword args only") @@ -801,10 +816,12 @@ class Experiment(object): # safe to convert for both cases. hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator) if self._core_estimator_used: - return self._estimator.train(input_fn=input_fn, - steps=steps, - max_steps=max_steps, - hooks=hooks) + return self._estimator.train( + input_fn=input_fn, + steps=steps, + max_steps=max_steps, + hooks=hooks, + saving_listeners=saving_listeners) else: return self._estimator.fit(input_fn=input_fn, steps=steps, diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index fe40d27c445d4f560c96fc9b50ceb0daed30ee93..c29c198d094090a59c8c7dd2949c3f069adf49d0 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -232,14 +232,19 @@ class ExperimentTest(test.TestCase): def test_train(self): for est in self._estimators_for_tests(): - eval_metrics = 'eval_metrics' if not isinstance( - est, core_estimator.Estimator) else None + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None ex = experiment.Experiment( est, train_input_fn='train_input', train_steps='train_steps', eval_input_fn='eval_input', - eval_metrics=eval_metrics) + eval_metrics=eval_metrics, + saving_listeners=saving_listeners) fit_args = ex.train(delay_secs=0) self.assertEqual(1, est.fit_count) self.assertIn(('max_steps', 'train_steps'), fit_args) @@ -675,8 +680,12 @@ class ExperimentTest(test.TestCase): def test_continuous_train_and_eval(self): for est in self._estimators_for_tests(eval_dict={'global_step': 100}): - eval_metrics = 'eval_metrics' if not isinstance( - est, core_estimator.Estimator) else None + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None noop_hook = _NoopHook() export_strategy = saved_model_export_utils.make_export_strategy( est, @@ -690,7 +699,8 @@ class ExperimentTest(test.TestCase): eval_hooks=[noop_hook], train_steps=100, eval_steps=100, - export_strategies=export_strategy) + export_strategies=export_strategy, + saving_listeners=saving_listeners) ex.continuous_train_and_eval() self.assertEqual(1, est.fit_count) self.assertEqual(1, est.eval_count) @@ -742,9 +752,10 @@ class ExperimentTest(test.TestCase): ex.continuous_train_and_eval(continuous_eval_predicate_fn=predicate_fn) mock_estimator.train.assert_called_once_with( input_fn='train_input', - steps=int(total_steps/10), + steps=int(total_steps / 10), max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_steps_per_iteration_from_user(self): mock_estimator = test.mock.Mock(core_estimator.Estimator) @@ -768,7 +779,8 @@ class ExperimentTest(test.TestCase): input_fn='train_input', steps=1234, max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_default_steps_per_iteration(self): mock_estimator = test.mock.Mock(core_estimator.Estimator) @@ -791,7 +803,8 @@ class ExperimentTest(test.TestCase): input_fn='train_input', steps=1000, max_steps=test.mock.ANY, - hooks=test.mock.ANY) + hooks=test.mock.ANY, + saving_listeners=test.mock.ANY) def test_continuous_train_and_eval_with_invalid_predicate_fn(self): for est in self._estimators_for_tests(): @@ -857,11 +870,19 @@ class ExperimentTest(test.TestCase): est, None if isinstance(est, core_estimator.Estimator) else 'export_input', exports_to_keep=None) + if isinstance(est, core_estimator.Estimator): + eval_metrics = None + saving_listeners = 'saving_listeners' + else: + eval_metrics = 'eval_metrics' + saving_listeners = None ex = experiment.Experiment( est, train_input_fn='train_input', eval_input_fn='eval_input', - export_strategies=(exp_strategy,)) + export_strategies=(exp_strategy,), + eval_metrics=eval_metrics, + saving_listeners=saving_listeners) ex.test() self.assertEqual(1, est.fit_count) self.assertEqual(1, est.eval_count) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 4c50d40aaa9b3c5d94d0a66d08e8ab6173db427a..86fad4c5535a918d87e0741687cfebe3afaf9ddf 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -28,13 +28,13 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels - # pylint: enable=g-multiple-import,g-bad-import-order @@ -365,8 +365,14 @@ class DataFeeder(object): self.random_state = np.random.RandomState( 42) if random_state is None else random_state - num_samples = list(self._x.values())[0].shape[ - 0] if x_is_dict else self._x.shape[0] + if x_is_dict: + num_samples = list(self._x.values())[0].shape[0] + elif tensor_util.is_tensor(self._x): + num_samples = self._x.shape[ + 0].value # shape will be a Dimension, extract an int + else: + num_samples = self._x.shape[0] + if self._shuffle: self.indices = self.random_state.permutation(num_samples) else: diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 4b34fc62849766370979bb2002d42ee03ea7161a..3a46c239688017f9204d2c6182a6f81cd325a417 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import io_ops @@ -280,14 +281,33 @@ def _get_file_names(file_pattern, randomize_input): def _get_examples(file_name_queue, reader, num_threads, read_batch_size, filter_fn, parse_fn): + """Get example filenames matching. + + Args: + file_name_queue: A queue implementation that dequeues elements in + first-in first-out order. + reader: A function or class that returns an object with + `read` method, (filename tensor) -> (example tensor). + num_threads: The number of threads enqueuing examples. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. + filter_fn: Filtering function, takes both keys as well as an `Example` + Tensors and returns a boolean mask of the same shape as the input Tensors + to be applied for filtering. If `None`, no filtering is done. + parse_fn: Parsing function, takes `Example` Tensor returns parsed + representation. If `None`, no parsing is done. + + Returns: + List of example file names matching `file_name_queue`. + """ with ops.name_scope('read'): example_list = [] for _ in range(num_threads): - if read_batch_size > 1: - keys, examples_proto = reader().read_up_to(file_name_queue, - read_batch_size) - else: - keys, examples_proto = reader().read(file_name_queue) + keys, examples_proto = utils.smart_cond( + read_batch_size > 1, + lambda: reader().read_up_to(file_name_queue, read_batch_size), + lambda: reader().read(file_name_queue)) + if filter_fn: mask = filter_fn(keys, examples_proto) keys = array_ops.boolean_mask(keys, mask) @@ -379,14 +399,15 @@ def _read_keyed_batch_examples_helper(file_pattern, capacity=1, dtypes=[dtypes.string], shapes=[[]]) enqueue_op = file_name_queue.enqueue( input_pipeline_ops.seek_next( - file_names, shuffle=randomize_input, num_epochs=num_epochs, + file_names, + shuffle=randomize_input, + num_epochs=num_epochs, seed=seed)) queue_runner.add_queue_runner( queue_runner.QueueRunner(file_name_queue, [enqueue_op])) else: file_name_queue = input_ops.string_input_producer( - constant_op.constant( - file_names, name='input'), + constant_op.constant(file_names, name='input'), shuffle=randomize_input, num_epochs=num_epochs, name=file_name_queue_scope, @@ -496,7 +517,8 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: - if read_batch_size is None: read_batch_size = batch_size + if read_batch_size is None: + read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 6f0fd9a2976d37d1c701a96f50c2b987562cb191..e11e8b698adc113486bbb45572c8129e964cc931 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -204,8 +204,7 @@ class GraphIOTest(test.TestCase): shape = (0,) features = { "feature": - parsing_ops.FixedLenFeature( - shape=shape, dtype=dtypes_lib.float32) + parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32) } with ops.Graph().as_default() as g, self.test_session(graph=g) as sess: @@ -255,8 +254,8 @@ class GraphIOTest(test.TestCase): self.assertAllEqual((None,), inputs.get_shape().as_list()) self.assertEqual("%s:1" % name, inputs.name) file_name_queue_name = "%s/file_name_queue" % name - file_name_queue_limit_name = ("%s/limit_epochs/epochs" % - file_name_queue_name) + file_name_queue_limit_name = ( + "%s/limit_epochs/epochs" % file_name_queue_name) file_names_name = "%s/input" % file_name_queue_name example_queue_name = "%s/random_shuffle_queue" % name op_nodes = test_util.assert_ops_in_graph({ @@ -354,8 +353,8 @@ class GraphIOTest(test.TestCase): json_lines = [ "".join([ '{"features": { "feature": { "sequence": {', - '"bytes_list": { "value": ["', base64.b64encode(l).decode("ascii"), - '"]}}}}}\n' + '"bytes_list": { "value": ["', + base64.b64encode(l).decode("ascii"), '"]}}}}}\n' ]) for l in lines ] return self._create_temp_file("".join(json_lines)) @@ -823,6 +822,31 @@ class GraphIOTest(test.TestCase): coord.request_stop() coord.join(threads) + def test_read_keyed_batch_features_shared_queue(self): + batch_size = 17 + shape = (0,) + fixed_feature = parsing_ops.FixedLenFeature( + shape=shape, dtype=dtypes_lib.float32) + feature = {"feature": fixed_feature} + reader = io_ops.TFRecordReader + + _, queued_feature = graph_io.read_keyed_batch_features_shared_queue( + _VALID_FILE_PATTERN, batch_size, feature, reader) + + with ops.Graph().as_default() as g, self.test_session(graph=g) as session: + features_result = graph_io.read_batch_features( + _VALID_FILE_PATTERN, batch_size, feature, reader) + session.run(variables.local_variables_initializer()) + + self.assertAllEqual( + queued_feature.get("feature").get_shape().as_list(), + features_result.get("feature").get_shape().as_list()) + + def test_get_file_names_errors(self): + # Raise bad file_pattern. + with self.assertRaises(ValueError): + graph_io._get_file_names([], True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index ed6683abedbb8ae76ba364405158eb52cbb6d762..6440bc204b8e339ff51311dcc87b36f556b94092 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -42,10 +42,8 @@ def _args(fn): """ if hasattr(fn, 'func') and hasattr(fn, 'keywords'): # Handle functools.partial and similar objects. - return tuple([ - arg for arg in tf_inspect.getargspec(fn.func).args - if arg not in set(fn.keywords.keys()) - ]) + return tuple( + [arg for arg in _args(fn.func) if arg not in set(fn.keywords.keys())]) # Handle function. return tuple(tf_inspect.getargspec(fn).args) diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 6af2287761299f6725f9547917101c18b0cc0164..cb34cb1d26b6812c7f3f39e9f965615de5a8ef07 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -78,7 +78,7 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, default_graph_signature=default_graph_signature, named_graph_signatures=named_graph_signatures, assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) - return export.export(export_dir, contrib_variables.get_global_step(), + return export.export(export_dir, training_util.get_global_step(), session, exports_to_keep=exports_to_keep) @@ -295,7 +295,7 @@ def _export_estimator(estimator, checkpoint_path = (checkpoint_path or tf_saver.latest_checkpoint(estimator._model_dir)) with ops.Graph().as_default() as g: - contrib_variables.create_global_step(g) + training_util.create_global_step(g) if use_deprecated_input_fn: examples = array_ops.placeholder(dtype=dtypes.string, diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 49413092a6bae547ddd2cad272b1abb3af1de046..6ffd2a133995a6ff8b35540221fb5676bf5de19f 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -33,6 +33,7 @@ from __future__ import division from __future__ import print_function import os +import tempfile import time from tensorflow.contrib.layers.python.layers import feature_column @@ -644,18 +645,22 @@ def make_best_model_export_strategy(serving_input_fn, # TODO(b/67013778): Revisit this approach when corresponding changes to # TF Core are finalized. -def extend_export_strategy(base_export_strategy, post_export_fn, - post_export_name): +def extend_export_strategy(base_export_strategy, + post_export_fn, + post_export_name=None): """Extend ExportStrategy, calling post_export_fn after export. Args: base_export_strategy: An ExportStrategy that can be passed to the Experiment constructor. post_export_fn: A user-specified function to call after exporting the - SavedModel. Takes the export directory as an argument, and returns - a string path to a (potentially different) SavedModel. + SavedModel. Takes two arguments - the path to the SavedModel exported by + base_export_strategy and the directory where to export the SavedModel + modified by the post_export_fn. Returns the path to the exported + SavedModel. post_export_name: The directory name under the export base directory where - SavedModels generated by the post_export_fn will be written. + SavedModels generated by the post_export_fn will be written. If None, the + directory name of base_export_strategy is used. Returns: An ExportStrategy that can be passed to the Experiment constructor. @@ -675,12 +680,24 @@ def extend_export_strategy(base_export_strategy, post_export_fn, Raises: ValueError: If `estimator` is a ${tf.estimator.Estimator} instance - and `default_output_alternative_key` was specified. + and `default_output_alternative_key` was specified or if post_export_fn + does not return a valid directory. """ - export_dir = base_export_strategy.export(estimator, export_dir_base, - checkpoint_path) - if post_export_fn: - export_dir = post_export_fn(export_dir) - return export_dir - - return export_strategy.ExportStrategy(post_export_name, export_fn) + tmp_base_export_dir = tempfile.mkdtemp() + tmp_base_export = base_export_strategy.export( + estimator, tmp_base_export_dir, checkpoint_path) + tmp_post_export_dir = tempfile.mkdtemp() + tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir) + + if not tmp_post_export.startswith(tmp_post_export_dir): + raise ValueError('post_export_fn must return a sub-directory of {}' + .format(tmp_post_export_dir)) + export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) + + gfile.Rename( + os.path.join(tmp_post_export_dir, export_relpath), + os.path.join(export_dir_base, export_relpath)) + return os.path.join(export_dir_base, export_relpath) + + name = post_export_name if post_export_name else base_export_strategy.name + return export_strategy.ExportStrategy(name, export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 27f17b54221ea442baafb382aa3fb034d1bb82e6..ec3a88003f01b3b62591c13472029601b11ba491 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -743,12 +743,19 @@ class SavedModelExportUtilsTest(test.TestCase): None) def test_extend_export_strategy(self): - def _base_export_fn(unused_estimator, export_dir_base, + + def _base_export_fn(unused_estimator, + export_dir_base, unused_checkpoint_path=None): - return export_dir_base + "/e1" + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path - def _post_export_fn(orig_path): - return orig_path + "/rewrite" + def _post_export_fn(orig_path, new_path): + assert orig_path.endswith("/e1") + post_export_path = os.path.join(new_path, "rewrite") + gfile.MkDir(post_export_path) + return post_export_path base_export_strategy = export_strategy_lib.ExportStrategy( "Servo", _base_export_fn) @@ -758,9 +765,67 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertEqual(final_export_strategy.name, "Servo2") test_estimator = TestEstimator() - final_path = final_export_strategy.export(test_estimator, "/path/to/orig", - "/path/to/checkpoint") - self.assertEqual("/path/to/orig/e1/rewrite", final_path) + tmpdir = tempfile.mkdtemp() + final_path = final_export_strategy.export(test_estimator, tmpdir, + os.path.join( + tmpdir, "checkpoint")) + self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + + def test_extend_export_strategy_same_name(self): + + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(orig_path, new_path): + assert orig_path.endswith("/e1") + post_export_path = os.path.join(new_path, "rewrite") + gfile.MkDir(post_export_path) + return post_export_path + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn) + self.assertEqual(final_export_strategy.name, "Servo") + + test_estimator = TestEstimator() + tmpdir = tempfile.mkdtemp() + final_path = final_export_strategy.export(test_estimator, tmpdir, + os.path.join( + tmpdir, "checkpoint")) + self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + + def test_extend_export_strategy_raises_error(self): + + def _base_export_fn(unused_estimator, + export_dir_base, + unused_checkpoint_path=None): + base_path = os.path.join(export_dir_base, "e1") + gfile.MkDir(base_path) + return base_path + + def _post_export_fn(unused_orig_path, unused_new_path): + return tempfile.mkdtemp() + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn) + + test_estimator = TestEstimator() + tmpdir = tempfile.mkdtemp() + with self.assertRaises(ValueError) as ve: + final_export_strategy.export(test_estimator, tmpdir, + os.path.join(tmpdir, "checkpoint")) + + self.assertTrue( + "post_export_fn must return a sub-directory" in str(ve.exception)) def _create_test_export_dir(export_dir_base): diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index 8313aa355d6d40596b40c39f28b64f46c1bb5719..5e7b422e3cc368a22eb94ed470297ae78293c4eb 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -76,7 +76,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest # TODO(ebrevdo): Remove once _linear is fully deprecated. -Linear = rnn_cell_impl._Linear # pylint: disable=protected-access,invalid-name +Linear = core_rnn_cell._Linear # pylint: disable=protected-access,invalid-name def _extract_argmax_and_embed(embedding, diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 13f2f0f5021ea4dd339b671e20cb718f4db509f9..7526f3ae0dbdb3d6827e9d7f690090b8438e4f6e 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -238,10 +238,10 @@ class SdcaModel(object): with name_scope('sdca/prediction'): sparse_variables = self._convert_n_to_tensor(self._variables[ 'sparse_features_weights']) - result = 0.0 + result_sparse = 0.0 for sfc, sv in zip(examples['sparse_features'], sparse_variables): # TODO(sibyl-Aix6ihai): following does not take care of missing features. - result += math_ops.segment_sum( + result_sparse += math_ops.segment_sum( math_ops.multiply( array_ops.gather(sv, sfc.feature_indices), sfc.feature_values), sfc.example_indices) @@ -249,12 +249,14 @@ class SdcaModel(object): dense_variables = self._convert_n_to_tensor(self._variables[ 'dense_features_weights']) + result_dense = 0.0 for i in range(len(dense_variables)): - result += math_ops.matmul(dense_features[i], - array_ops.expand_dims(dense_variables[i], -1)) + result_dense += math_ops.matmul(dense_features[i], + array_ops.expand_dims( + dense_variables[i], -1)) # Reshaping to allow shape inference at graph construction time. - return array_ops.reshape(result, [-1]) + return array_ops.reshape(result_dense, [-1]) + result_sparse def predictions(self, examples): """Add operations to compute predictions by the model. diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 701fc1c0597d1de0b0189e86feafbd1c5bbdc818..05794a42c5f2d0eece6adab36fb5610078cece31 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -154,7 +154,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step( columns_to_variables, weight_column_name, loss_type, features, labels, global_step) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3f1b0be1a73a3ff1da3452f4ee1a9125f9e26178 --- /dev/null +++ b/tensorflow/contrib/lite/BUILD @@ -0,0 +1,204 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + +exports_files(glob([ + "testdata/*.bin", + "models/testdata/*", +])) + +config_setting( + name = "mips", + values = { + "cpu": "mips", + }, +) + +config_setting( + name = "mips64", + values = { + "cpu": "mips64", + }, +) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "schema_fbs_version", + hdrs = ["version.h"], +) + +# Main library. No ops are included here. +# TODO(aselle): Resolve problems preventing C99 usage. +cc_library( + name = "context", + srcs = ["context.c"], + hdrs = ["context.h"], +) + +cc_library( + name = "builtin_op_data", + hdrs = [ + "builtin_op_data.h", + ], +) + +cc_library( + name = "string", + hdrs = [ + "string.h", + ], + deps = [ + "//tensorflow/core:lib_platform", + ], +) + +# TODO(ahentz): investigate dependency on gemm_support requiring usage of tf_copts. +cc_library( + name = "framework", + srcs = [ + "allocation.cc", + "error_reporter.cc", + "interpreter.cc", + "model.cc", + "nnapi_delegate.cc", + "optional_debug_tools.cc", + "simple_memory_arena.cc", + ], + hdrs = [ + "allocation.h", + "context.h", + "error_reporter.h", + "interpreter.h", + "model.h", + "nnapi_delegate.h", + "optional_debug_tools.h", + "simple_memory_arena.h", + ], + copts = tflite_copts(), + deps = [ + ":builtin_op_data", + ":context", + ":schema_fbs_version", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/nnapi:nnapi_lib", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/core:lib_platform", + ], +) + +cc_library( + name = "string_util", + srcs = ["string_util.cc"], + hdrs = ["string_util.h"], + deps = [ + ":framework", + ":string", + ], +) + +cc_test( + name = "string_util_test", + size = "small", + srcs = ["string_util_test.cc"], + deps = [ + ":framework", + ":string_util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test main interpreter +cc_test( + name = "interpreter_test", + size = "small", + srcs = ["interpreter_test.cc"], + deps = [ + ":framework", + ":string_util", + "@com_google_googletest//:gtest", + ], +) + +# Test arena allocator +cc_test( + name = "simple_memory_arena_test", + size = "small", + srcs = ["simple_memory_arena_test.cc"], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test model framework. +cc_test( + name = "model_test", + size = "small", + srcs = ["model_test.cc"], + data = [ + "testdata/0_subgraphs.bin", + "testdata/2_subgraphs.bin", + "testdata/empty_model.bin", + "testdata/test_model.bin", + "testdata/test_model_broken.bin", + ], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test the C extension API code. +cc_test( + name = "context_test", + size = "small", + srcs = ["context_test.cc"], + deps = [ + ":framework", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test the serialization of a model with optional tensors. + +# Model tests + +cc_library( + name = "models_test_utils", + testonly = 1, + hdrs = ["models/test_utils.h"], + deps = select({ + "//tensorflow:android": [], + "//conditions:default": [ + "@com_google_absl//absl/strings", + "//tensorflow/core:test", + ], + }), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "downloads", + "examples", + "gen", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..78402727abdd2742ffff54bf59ca076d8b97b042 --- /dev/null +++ b/tensorflow/contrib/lite/Makefile @@ -0,0 +1,147 @@ + +# Find where we're running from, so we can store generated files here. +ifeq ($(origin MAKEFILE_DIR), undefined) + MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +endif + +# Try to figure out the host system +HOST_OS := +ifeq ($(OS),Windows_NT) + HOST_OS = WINDOWS +else + UNAME_S := $(shell uname -s) + ifeq ($(UNAME_S),Linux) + HOST_OS := LINUX + endif + ifeq ($(UNAME_S),Darwin) + HOST_OS := OSX + endif +endif + +ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) + +# Where compiled objects are stored. +OBJDIR := $(MAKEFILE_DIR)/gen/obj/ +BINDIR := $(MAKEFILE_DIR)/gen/bin/ +LIBDIR := $(MAKEFILE_DIR)/gen/lib/ +GENDIR := $(MAKEFILE_DIR)/gen/obj/ + +# Settings for the host compiler. +CXX := $(CC_PREFIX) gcc +CXXFLAGS := --std=c++11 -O3 -DNDEBUG +CC := $(CC_PREFIX) gcc +CFLAGS := +LDOPTS := +LDOPTS += -L/usr/local/lib +ARFLAGS := -r + +INCLUDES := \ +-I. \ +-I$(MAKEFILE_DIR)/../../../ \ +-I$(MAKEFILE_DIR)/downloads/ \ +-I$(MAKEFILE_DIR)/downloads/eigen \ +-I$(MAKEFILE_DIR)/downloads/gemmlowp \ +-I$(MAKEFILE_DIR)/downloads/neon_2_sse \ +-I$(MAKEFILE_DIR)/downloads/farmhash/src \ +-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \ +-I$(GENDIR) +# This is at the end so any globally-installed frameworks like protobuf don't +# override local versions in the source tree. +INCLUDES += -I/usr/local/include + +LIBS := \ +-lstdc++ \ +-lpthread \ +-lm \ +-lz + +# If we're on Linux, also link in the dl library. +ifeq ($(OS),LINUX) + LIBS += -ldl -lpthread +endif + +include $(MAKEFILE_DIR)/ios_makefile.inc + +# This library is the main target for this makefile. It will contain a minimal +# runtime that can be linked in to other programs. +LIB_NAME := libtensorflow-lite.a +LIB_PATH := $(LIBDIR)$(LIB_NAME) + +# A small example program that shows how to link against the library. +BENCHMARK_PATH := $(BINDIR)benchmark_model + +BENCHMARK_SRCS := \ +tensorflow/contrib/lite/tools/benchmark_model.cc +BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS)))) + +# What sources we want to compile, must be kept in sync with the main Bazel +# build files. + +CORE_CC_ALL_SRCS := \ +$(wildcard tensorflow/contrib/lite/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \ +$(wildcard tensorflow/contrib/lite/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \ +$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \ +$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc) +# Remove any duplicates. +CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS)) +CORE_CC_EXCLUDE_SRCS := \ +$(wildcard tensorflow/contrib/lite/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \ +$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \ +$(BENCHMARK_SRCS) +# Filter out all the excluded files. +TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) +# File names of the intermediate files target compilation generates. +TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \ +$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS)))) +LIB_OBJS := $(TF_LITE_CC_OBJS) + +# For normal manually-created TensorFlow C++ source files. +$(OBJDIR)%.o: %.cc + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +# For normal manually-created TensorFlow C++ source files. +$(OBJDIR)%.o: %.c + @mkdir -p $(dir $@) + $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@ + +# The target that's compiled if there's no command-line arguments. +all: $(LIB_PATH) $(BENCHMARK_PATH) + +# Gathers together all the objects we've compiled into a single '.a' archive. +$(LIB_PATH): $(LIB_OBJS) + @mkdir -p $(dir $@) + $(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS) + +$(BENCHMARK_PATH): $(BENCHMARK_OBJS) $(LIB_PATH) + @mkdir -p $(dir $@) + $(CXX) $(CXXFLAGS) $(INCLUDES) \ + -o $(BENCHMARK_PATH) $(BENCHMARK_OBJS) \ + $(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS) + +# Gets rid of all generated files. +clean: + rm -rf $(MAKEFILE_DIR)/gen + +# Gets rid of target files only, leaving the host alone. Also leaves the lib +# directory untouched deliberately, so we can persist multiple architectures +# across builds for iOS and Android. +cleantarget: + rm -rf $(OBJDIR) + rm -rf $(BINDIR) + +$(DEPDIR)/%.d: ; +.PRECIOUS: $(DEPDIR)/%.d + +-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS))) diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2fb40070cb25df16d32569ca764c181bf6333506 --- /dev/null +++ b/tensorflow/contrib/lite/README.md @@ -0,0 +1,222 @@ +# TensorFlow Lite +TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration. + +TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. + +![image](g3doc/TFLite-Architecture.jpg) +# Getting Started with an Android Demo App + +This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. + +There are 3 ways to get the demo app to your device + - Download the prebuilt binary or + - Use Android Studio to build the application or + - Download the source code for TensorFlow Lite and the demo and build it using bazel + +## Description +In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. + +## Downloading the pre-built binary +The fastest path to trying the demo, is to download the pre-built binary +[TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) + +Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. + +## Building in Android Studio using TensorFlow Lite AAR from JCenter +The simplest way to compile the demo app, and try out changes to the project code is to use AndroidStudio. + + - Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html). + - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). + - Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project. + - Click through installing all the Gradle extensions it requests. + - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) + - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: + `tensorflow/contrib/lite/java/demo/app/src/main/assets/` + - Build and run the demo app + +## Building TensorFlow Lite and the demo app from source + +### Clone the TensorFlow repo +- git clone + [https://github.com/tensorflow/tensorflow](https://github.com/tensorflow/tensorflow) + +### Install Bazel +If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html) + +NOTE: Bazel does not fully support building Android on Windows yet. Full support for Gradle/CMake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead. + +### Install Android NDK and SDK +Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system. + - Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html) + - The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads). + - The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices). + - In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with AndroidStudio the SDK path can be found in the SDK manager, and the default NDK path is:`{SDK path}/ndk-bundle.` + +``` +android_sdk_repository ( + name = "androidsdk", + api_level = 23, + build_tools_version = "23.0.2", + path = "/home/xxxx/android-sdk-linux/", +) + +android_ndk_repository( + name = "androidndk", + path = "/home/xxxx/android-ndk-r10e/", + api_level = 19, +) +``` + +Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). + +### Build the source code +Run bazel with the following command to build the demo. + +Build the demo app: + +``` +bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo +``` + +### Note + +Currently, we only support building the Android demo app within a Python 2 +environment (due to a Bazel bug). + +### More about the demo +The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app. + +# iOS Demo App + +Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). + +This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: + +1. Follow the Building section [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md#building) to build the universal iOS library for TensorFlow Lite. +1. Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. +1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. +1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. + +# TensorFlow Lite Quick Start + +## Step 1. Decide which GraphDef to use + Depending on the use case, the developer may choose to use one of the popular + open-sourced models such as InceptionV3 or MobileNets, re-train these models + with their own custom data set or even build their own custom model. + +### Using a pre-trained model + +[MobileNets](https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html) is a family of mobile-first computer vision models for [TensorFlow](https://www.tensorflow.org/) designed to effectively maximize accuracy while being mindful of the restricted resources for an on-device or embedded application. MobileNets are small, low-latency, low-power models parameterized to meet the resource constraints of a variety of use cases. They can be built upon for classification, detection, embeddings and segmentation similar to how other popular large scale models, such as [Inception](https://arxiv.org/pdf/1602.07261.pdf), are used. Google provides 16 pre-trained [ImageNet](http://www.image-net.org/challenges/LSVRC/) classification checkpoints for MobileNets for use in mobile projects of all sizes. + +[Inception-v3](https://arxiv.org/abs/1512.00567) is an image recognition model which achieves fairly high accuracy in recognizing general objects with 1000 classes, like "Zebra", "Dalmatian", and "Dishwasher". The model extracts general features from input images using a convolutional neural network and classifies them based on those features with fully-connected and softmax layers. + +[On Device Smart Reply](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) is an on-device model which provides one-touch replies for an incoming text message by suggesting contextually relevant messages. The model is built specifically for memory constrained devices such as watches & phones and it has been successfully used to surface [Smart Replies on Android Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html). Note that this model only works on Android as of now. + +These pre-trained models can be downloaded from [here](g3doc/models.md). + +### Retrain Inception-V3 or MobileNet for a custom data set +The above pre-trained models have been trained on the ImageNet data set, which consists of 1000 predefined classes. A model will need to be re-trained if these classes are not relevant or useful for a given use case. This technique is called transfer learning, which starts with a model that has been already trained on a problem and will then be retrained on a similar problem. Deep learning from scratch can take days, but transfer learning can be done fairly quickly. In order to do this, a developer will need to generate their custom data set labeled with the relevant classes. + +The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/) codelab walks through this process step-by-step. The retraining code supports retraining for both floating point and quantized inference. + + +### Train a custom model +A developer may choose to train a custom model using Tensorflow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model. + +TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This +set will continue to expand in future releases of Tensorflow Lite. + + +## Step 2. Model format conversion + +The model generated in Step 1 is a standard Tensorflow model. After the completion of Step 1 a user should have a standard .pb or .pbtxt GraphDef file. If the application developer is using a pre-trained model (as defined in Step 1 above), they can download a ready to use, already converted model for use from [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/models.md). Models generated using retraining (aka transfer learning) or custom models will need to be converted using the steps mentioned below. + +A prerequisite to converting the model to the Tensorflow Lite format is to freeze the graph. + +Since we employ several formats, the following definitions may be useful: + - GraphDef (.pb) - a protobuf that represents the TensorFlow training and or computation graph. This contains operators, tensors, and variables definitions. + + - CheckPoint (.ckpt) - Serialized variables from a TensorFlow graph. Note, this does not contain the graph structure, so alone it cannot typically be interpreted. + + - FrozenGraphDef - a subclass of GraphDef that contains no variables. A GraphDef can be converted to a frozen graphdef by taking a checkpoint and a graphdef and converting every variable into a constant with the value looked up in the checkpoint. + + - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model. + + - TensorFlow lite model (.lite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. + +### Freeze Graph +To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph. + +The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)). + +Graph freezing can be done using the command below (and modifying the arguments appropriately) + +``` +bazel build tensorflow/python/tools:freeze_graph + +bazel-bin/tensorflow/python/tools/freeze_graph\ + --input_graph=/tmp/mobilenet_v1_224.pb \ + --input_checkpoint=/tmp/checkpoints/mobilenet-10202.ckpt \ + --input_binary=true --output_graph=/tmp/frozen_mobilenet_v1_224.pb \ + --output_node_names=MobileNet/Predictions/Reshape_1 +``` + +The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with +graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). + +This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. + +Here is a sample command line to convert the frozen Graphdef to '.lite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. +(Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). + +``` +bazel build tensorflow/contrib/lite/toco:toco + +bazel-bin/tensorflow/contrib/lite/toco/toco -- \ + --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ + --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ + --output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \ + --input_type=FLOAT --input_arrays=input \ + --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 +``` + +- The input_file argument should point to the frozen GraphDef file that holds the model architecture. +- The output_file argument should point to where the TensorFlow Lite model file should be generated. +- The input_type and inference_type arguments should be set to FLOAT, unless converted a [quantized](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/) model. +- Setting the input_array, output_array and input_shape arguments are a bit trickier. The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. + +Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the +documentation [here](https://github.com/tensorflow/tensorflow/tree/mastertensorflow/contrib/lite/python:toco_from_protos target) A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, + +``` +import tensorflow as tf + +img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) +val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.]) +out = tf.identity(val, name="out") +with tf.Session() as sess: + tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out]) + open("converteds_model.tflite", "wb").write(tflite_model) + +``` +For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md). + +You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). + +## Step 3. Use the TensorFlow Lite model for inference in a mobile app + +After completion of Step 2 the developer should have a .lite model. + +### For Android +Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). + +The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it's a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk). + +Note that you'd need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build). + +### For iOS +Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. + +## Core ML support + +Core ML is a machine learning framework used across Apple products. In addition to using Tensorflow Lite models directly in their applications, developers have the option to convert their trained Tensorflow models to the [CoreML](https://developer.apple.com/machine-learning/) format for use on Apple devices. For information on how to use the converter please refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b322e027d48f4bf9f90d5b873c449d1ec31cc49 --- /dev/null +++ b/tensorflow/contrib/lite/allocation.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +namespace tflite { + +MMAPAllocation::MMAPAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter), mmapped_buffer_(MAP_FAILED) { + mmap_fd_ = open(filename, O_RDONLY); + if (mmap_fd_ == -1) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + struct stat sb; + fstat(mmap_fd_, &sb); + buffer_size_bytes_ = sb.st_size; + mmapped_buffer_ = + mmap(nullptr, buffer_size_bytes_, PROT_READ, MAP_SHARED, mmap_fd_, 0); + if (mmapped_buffer_ == MAP_FAILED) { + error_reporter_->Report("Mmap of '%s' failed.", filename); + return; + } +} + +MMAPAllocation::~MMAPAllocation() { + if (valid()) { + munmap(const_cast(mmapped_buffer_), buffer_size_bytes_); + } + if (mmap_fd_ != -1) close(mmap_fd_); +} + +const void* MMAPAllocation::base() const { return mmapped_buffer_; } + +size_t MMAPAllocation::bytes() const { return buffer_size_bytes_; } + +bool MMAPAllocation::valid() const { return mmapped_buffer_ != MAP_FAILED; } + +FileCopyAllocation::FileCopyAllocation(const char* filename, + ErrorReporter* error_reporter) + : Allocation(error_reporter) { + // Obtain the file size, using an alternative method that is does not + // require fstat for more compatibility. + std::unique_ptr file(fopen(filename, "rb"), fclose); + if (!file) { + error_reporter_->Report("Could not open '%s'.", filename); + return; + } + // TODO(ahentz): Why did you think using fseek here was better for finding + // the size? + struct stat sb; + if (fstat(fileno(file.get()), &sb) != 0) { + error_reporter_->Report("Failed to get file size of '%s'.", filename); + return; + } + buffer_size_bytes_ = sb.st_size; + std::unique_ptr buffer(new char[buffer_size_bytes_]); + if (!buffer) { + error_reporter_->Report("Malloc of buffer to hold copy of '%s' failed.", + filename); + return; + } + size_t bytes_read = + fread(buffer.get(), sizeof(char), buffer_size_bytes_, file.get()); + if (bytes_read != buffer_size_bytes_) { + error_reporter_->Report("Read of '%s' failed (too few bytes read).", + filename); + return; + } + copied_buffer_ = std::move(buffer); +} + +FileCopyAllocation::~FileCopyAllocation() {} + +const void* FileCopyAllocation::base() const { return copied_buffer_.get(); } + +size_t FileCopyAllocation::bytes() const { return buffer_size_bytes_; } + +bool FileCopyAllocation::valid() const { return copied_buffer_ != nullptr; } + +MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : Allocation(error_reporter) { + buffer_ = ptr; + buffer_size_bytes_ = num_bytes; +} + +MemoryAllocation::~MemoryAllocation() {} + +const void* MemoryAllocation::base() const { return buffer_; } + +size_t MemoryAllocation::bytes() const { return buffer_size_bytes_; } + +bool MemoryAllocation::valid() const { return buffer_ != nullptr; } + +} // namespace tflite diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h new file mode 100644 index 0000000000000000000000000000000000000000..ee8a7ccd0b232f9e48095567fd4aefe94f595bc3 --- /dev/null +++ b/tensorflow/contrib/lite/allocation.h @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Main abstraction controlling the tflite interpreter. +// See context.h for the API for defining operations (TfLiteRegistration). +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +namespace tflite { + +// A memory allocation handle. This could be a mmap or shared memory. +class Allocation { + public: + Allocation(ErrorReporter* error_reporter) : error_reporter_(error_reporter) {} + virtual ~Allocation() {} + + // Base pointer of this allocation + virtual const void* base() const = 0; + // Size in bytes of the allocation + virtual size_t bytes() const = 0; + // Whether the allocation is valid + virtual bool valid() const = 0; + + protected: + ErrorReporter* error_reporter_; +}; + +class MMAPAllocation : public Allocation { + public: + MMAPAllocation(const char* filename, ErrorReporter* error_reporter); + virtual ~MMAPAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + protected: + // Data required for mmap. + int mmap_fd_ = -1; // mmap file descriptor + const void* mmapped_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class FileCopyAllocation : public Allocation { + public: + FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); + virtual ~FileCopyAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + // Data required for mmap. + std::unique_ptr copied_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class MemoryAllocation : public Allocation { + public: + // Allocates memory with the pointer and the number of bytes of the memory. + // The pointer has to remain alive and unchanged until the destructor is + // called. + MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter); + virtual ~MemoryAllocation(); + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + const void* buffer_; + size_t buffer_size_bytes_ = 0; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl new file mode 100644 index 0000000000000000000000000000000000000000..d1fcdce70a34393defce0f2d0f6d5bb53f21c45e --- /dev/null +++ b/tensorflow/contrib/lite/build_def.bzl @@ -0,0 +1,235 @@ +"""Generate Flatbuffer binary from json.""" + +def tflite_copts(): + """Defines compile time flags.""" + copts = [ + "-DFARMHASH_NO_CXX_STRING", + ] + select({ + "//tensorflow:android_arm64": [ + "-std=c++11", + "-O3", + ], + "//tensorflow:android_arm": [ + "-mfpu=neon", + "-mfloat-abi=softfp", + "-std=c++11", + "-O3", + ], + "//tensorflow:android_x86": [ + "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", + ], + "//tensorflow:ios_x86_64": [ + "-msse4.1", + ], + "//conditions:default": [], + }) + select({ + "//tensorflow:with_default_optimizations": [], + "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], + }) + + return copts + +LINKER_SCRIPT = "//tensorflow/contrib/lite/java/src/main/native:version_script.lds" + +def tflite_linkopts_unstripped(): + """Defines linker flags to reduce size of TFLite binary. + + These are useful when trying to investigate the relative size of the + symbols in TFLite. + + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj. + "-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export. + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) + +def tflite_jni_linkopts_unstripped(): + """Defines linker flags to reduce size of TFLite binary with JNI. + + These are useful when trying to investigate the relative size of the + symbols in TFLite. + + Returns: + a select object with proper linkopts + """ + return select({ + "//tensorflow:android": [ + "-Wl,--gc-sections", # Eliminate unused code and data. + "-Wl,--as-needed", # Don't link unused libs. + ], + "//tensorflow/contrib/lite:mips": [], + "//tensorflow/contrib/lite:mips64": [], + "//conditions:default": [ + "-Wl,--icf=all", # Identical code folding. + ], + }) + +def tflite_linkopts(): + """Defines linker flags to reduce size of TFLite binary.""" + return tflite_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + ], + "//conditions:default": [], + }) + +def tflite_jni_linkopts(): + """Defines linker flags to reduce size of TFLite binary with JNI.""" + return tflite_jni_linkopts_unstripped() + select({ + "//tensorflow:android": [ + "-s", # Omit symbol table. + "-latomic", # Required for some uses of ISO C++11 in x86. + ], + "//conditions:default": [], + }) + + +def tflite_jni_binary(name, + copts=tflite_copts(), + linkopts=tflite_jni_linkopts(), + linkscript=LINKER_SCRIPT, + linkshared=1, + linkstatic=1, + deps=[]): + """Builds a jni binary for TFLite.""" + linkopts = linkopts + [ + "-Wl,--version-script", # Export only jni functions & classes. + linkscript, + ] + native.cc_binary( + name=name, + copts=copts, + linkshared=linkshared, + linkstatic=linkstatic, + deps= deps + [linkscript], + linkopts=linkopts) + +def tf_to_tflite(name, src, options, out): + """Convert a frozen tensorflow graphdef to TF Lite's flatbuffer. + + Args: + name: Name of rule. + src: name of the input graphdef file. + options: options passed to TOCO. + out: name of the output flatbuffer file. + """ + + toco = "//tensorflow/contrib/lite/toco:toco" + native.genrule( + name = name, + srcs=[src, options], + outs=[out], + cmd = ("$(location %s) " + + " --input_file=$(location %s) " + + " --output_file=$(location %s) " + + " --input_format=TENSORFLOW_GRAPHDEF" + + " --output_format=TFLITE" + + " `cat $(location %s)`") + % (toco, src, out, options), + tools= [toco], + ) + +def tflite_to_json(name, src, out): + """Convert a TF Lite flatbuffer to JSON. + + Args: + name: Name of rule. + src: name of the input flatbuffer file. + out: name of the output JSON file. + """ + + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema.fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.bin &&" + + "$(location %s) --raw-binary --strict-json -t" + + " -o /tmp $(location %s) -- $${TMP}.bin &&" + + "cp $${TMP}.json $(location %s)") + % (src, flatc, schema, out), + tools = [flatc], + ) + +def json_to_tflite(name, src, out): + """Convert a JSON file to TF Lite's flatbuffer. + + Args: + name: Name of rule. + src: name of the input JSON file. + out: name of the output flatbuffer file. + """ + + flatc = "@flatbuffers//:flatc" + schema = "//tensorflow/contrib/lite/schema:schema_fbs" + native.genrule( + name = name, + srcs = [schema, src], + outs = [out], + cmd = ("TMP=`mktemp`; cp $(location %s) $${TMP}.json &&" + + "$(location %s) --raw-binary --unknown-json --allow-non-utf8 -b" + + " -o /tmp $(location %s) $${TMP}.json &&" + + "cp $${TMP}.bin $(location %s)") + % (src, flatc, schema, out), + tools = [flatc], + ) + +def gen_zipped_test_files(name, files): + """Generate a zip file of tests by using :generate_examples. + + Args: + name: Name of output. We will produce "`name`_files" as a target. + files: A list of zip file basenames. + """ + toco = "//tensorflow/contrib/lite/toco:toco" + out_files = [] + for f in files: + out_file = name + "/" + f + out_files.append(out_file) + native.genrule( + name = name + "_" + f + ".files", + cmd = ("$(locations :generate_examples) --toco $(locations %s) " % toco + + " --zip_to_output " + f + + " $(@D) zipped"), + outs = [out_file], + tools = [ + ":generate_examples", + toco, + ], + ) + + native.filegroup( + name = name, + srcs = out_files, + ) + +def gen_selected_ops(name, model): + """Generate the library that includes only used ops. + + Args: + name: Name of the generated library. + model: TFLite model to interpret. + """ + out = name + "_registration.cc" + tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/contrib/lite" + native.genrule( + name = name, + srcs = [model], + outs = [out], + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") + % (tool, model, out, tflite_path[2:]), + tools = [tool], + ) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh new file mode 100755 index 0000000000000000000000000000000000000000..cbc96e6edd4358f6666731caa4c208c77d9c6c54 --- /dev/null +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -0,0 +1,31 @@ +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8 +make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8 + +lipo \ +tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \ +tensorflow/contrib/lite/gen/lib/ios_i386/libtensorflow-lite.a \ +tensorflow/contrib/lite/gen/lib/ios_armv7/libtensorflow-lite.a \ +tensorflow/contrib/lite/gen/lib/ios_armv7s/libtensorflow-lite.a \ +tensorflow/contrib/lite/gen/lib/ios_arm64/libtensorflow-lite.a \ +-create \ +-output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h new file mode 100644 index 0000000000000000000000000000000000000000..93072bf90bd8a18d9011a74c2eec95d86dbdce8a --- /dev/null +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -0,0 +1,164 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// TODO(aselle): Consider using "if this then that" for testing. + +// Possible padding types (for convolutions) +typedef enum { + kTfLitePaddingUnknown = 0, + kTfLitePaddingSame, + kTfLitePaddingValid, +} TfLitePadding; + +typedef struct { + int width; + int height; +} TfLitePaddingValues; + +// Possible fused activation functions. +// TODO(aselle): rename to TfLiteActivation +typedef enum { + kTfLiteActNone = 0, + kTfLiteActRelu, + kTfLiteActRelu1, + kTfLiteActRelu6, + kTfLiteActTanh, + kTfLiteActSignBit, + kTfLiteActSigmoid, +} TfLiteFusedActivation; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + TfLiteFusedActivation activation; +} TfLiteConvParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int filter_width; + int filter_height; + TfLiteFusedActivation activation; + struct { + TfLitePaddingValues padding; + } computed; +} TfLitePoolParams; + +typedef struct { + TfLitePadding padding; + int stride_width; + int stride_height; + int depth_multiplier; + TfLiteFusedActivation activation; +} TfLiteDepthwiseConvParams; + +typedef struct { + int rank; + TfLiteFusedActivation activation; +} TfLiteSVDFParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteRNNParams; + +typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams; + +typedef enum { + kTfLiteLshProjectionUnknown = 0, + kTfLiteLshProjectionSparse = 1, + kTfLiteLshProjectionDense = 2, +} TfLiteLSHProjectionType; + +typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams; + +typedef struct { float beta; } TfLiteSoftmaxParams; + +typedef struct { + int axis; + TfLiteFusedActivation activation; +} TfLiteConcatenationParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteAddParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteMulParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteL2NormParams; + +typedef struct { + int radius; + float bias; + float alpha; + float beta; +} TfLiteLocalResponseNormParams; + +typedef struct { + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; +} TfLiteLSTMParams; + +typedef struct { + int new_height; + int new_width; +} TfLiteResizeBilinearParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int shape[8]; + int num_dimensions; +} TfLiteReshapeParams; + +typedef struct { + int ngram_size; + int max_skip_size; + bool include_all_ngrams; +} TfLiteSkipGramParams; + +typedef struct { + int block_size; +} TfLiteSpaceToDepthParams; + +typedef enum { + kTfLiteCombinerTypeSum = 0, + kTfLiteCombinerTypeMean = 1, + kTfLiteCombinerTypeSqrtn = 2, +} TfLiteCombinerType; + +typedef struct { + TfLiteCombinerType combiner; +} TfLiteEmbeddingLookupSparseParams; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/context.c new file mode 100644 index 0000000000000000000000000000000000000000..c09e838c5c2e50e0f4a38eaf66e55246fd9a6f7f --- /dev/null +++ b/tensorflow/contrib/lite/context.c @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include +#include + +TfLiteIntArray* TfLiteIntArrayCreate(int size) { + TfLiteIntArray* ret = + (TfLiteIntArray*)malloc(sizeof(*ret) + sizeof(ret->data[0]) * size); + ret->size = size; + return ret; +} + +void TfLiteIntArrayPrint(const char* s, TfLiteIntArray* a) { + printf("%s: length=%d [", s, a->size); + if (a->size) printf("%d", a->data[0]); + int i = 1; + for (; i < a->size; i++) { + printf(" %d", a->data[i]); + } + printf("]\n"); +} + +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b) { + if (a == b) return 1; + if (a == NULL || b == NULL) return 0; + if (a->size != b->size) return 0; + int i = 0; + for (; i < a->size; i++) + if (a->data[i] != b->data[i]) return 0; + return 1; +} + +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src) { + if (!src) return NULL; + TfLiteIntArray* ret = TfLiteIntArrayCreate(src->size); + if (ret) { + memcpy(ret->data, src->data, src->size * sizeof(int)); + } + return ret; +} + +void TfLiteIntArrayFree(TfLiteIntArray* a) { free(a); } + +void TfLiteTensorFree(TfLiteTensor* t) { + if (t->allocation_type == kTfLiteDynamic && t->data.raw) { + free(t->data.raw); + } + if (t->dims) TfLiteIntArrayFree(t->dims); + t->data.raw = NULL; + t->dims = NULL; +} + +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, TfLiteTensor* tensor) { + TfLiteTensorFree(tensor); + tensor->type = type; + tensor->name = name; + tensor->dims = dims; + tensor->params = quantization; + tensor->data.raw = buffer; + tensor->bytes = size; + tensor->allocation_type = allocation_type; + tensor->allocation = allocation; +} + +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) { + if (tensor->allocation_type != kTfLiteDynamic) { + return; + } + if (!tensor->data.raw) { + tensor->data.raw = malloc(num_bytes); + } else if (num_bytes > tensor->bytes) { + tensor->data.raw = realloc(tensor->data.raw, num_bytes); + } + tensor->bytes = num_bytes; +} diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h new file mode 100644 index 0000000000000000000000000000000000000000..41257a53b145cbe7e252c9d4de6ea7ef654431b5 --- /dev/null +++ b/tensorflow/contrib/lite/context.h @@ -0,0 +1,298 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file defines a C API for implementing operations in tflite. +// These operations can be defined using c++ but the interface between +// the interpreter and the operations are C. +// +// Summary of abstractions +// TF_LITE_ENSURE - Self-sufficient error checking +// TfLiteStatus - Status reporting +// TfLiteIntArray - stores tensor shapes (dims), +// TfLiteContext - allows an op to access the tensors +// TfLiteTensor - tensor (a multidimensional array) +// TfLiteNode - a single node or operation +// TfLiteRegistration - the implementation of a conceptual operation. +// +// Some abstractions in this file are created and managed by Interpreter. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; + +#define kOptionalTensor (-1) + +// Fixed size list of integers. Used for dimensions and inputs/outputs tensor +// indices +typedef struct { + int size; +// gcc 6.1+ have a bug where flexible members aren't properly handled +// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c +#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \ + __GNUC_MINOR__ >= 1 + int data[0]; +#else + int data[]; +#endif +} TfLiteIntArray; + +// Create a array of a given `size` (uninitialized entries). +// This returns a pointer, that you must free using TfLiteIntArrayFree(). +TfLiteIntArray* TfLiteIntArrayCreate(int size); + +// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise. +int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b); + +// Create a copy of an array passed as `src`. +// You are expected to free memory with TfLiteIntArrayFree +TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src); + +// Free memory of array `v`. +void TfLiteIntArrayFree(TfLiteIntArray* v); + +// Since we must not depend on any libraries, define a minimal subset of +// error macros while avoiding names that have pre-conceived meanings like +// assert and check. + +// Check whether value is true, and if not return kTfLiteError from +// the current function (and report the error string msg). +#define TF_LITE_ENSURE_MSG(context, value, msg) \ + do { \ + if (!(value)) { \ + (context)->ReportError((context), __FILE__ " " msg); \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +#define TF_LITE_ENSURE(context, a) \ + do { \ + if (!(a)) { \ + (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \ + __LINE__, #a); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_STATUS(a) \ + do { \ + if ((a) != kTfLiteOk) { \ + return kTfLiteError; \ + } \ + } while (0) + +// Check whether the value `a == b` is true, and if not return kTfLiteError from +// the current function, while also reporting the location of the error. +// `a` and `b` may be evaluated more than once, so no side effects or +// extremely expensive computations should be done. +#define TF_LITE_ENSURE_EQ(context, a, b) \ + do { \ + if ((a) != (b)) { \ + (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \ + __LINE__, #a, #b, (a), (b)); \ + return kTfLiteError; \ + } \ + } while (0) + +#define TF_LITE_ENSURE_OK(context, status) \ + do { \ + if ((status) != kTfLiteOk) { \ + return status; \ + } \ + } while (0) + +// Types supported by tensor +typedef enum { + kTfLiteNoType = 0, + kTfLiteFloat32 = 1, + kTfLiteInt32 = 2, + kTfLiteUInt8 = 3, + kTfLiteInt64 = 4, + kTfLiteString = 5, +} TfLiteType; + +// Parameters for asymmetric quantization. Quantized values can be converted +// back to float using: +// real_value = scale * (quantized_value - zero_point); +typedef struct { + float scale; + int32_t zero_point; +} TfLiteQuantizationParams; + +// A union of points that points to memory for a given tensor. +typedef union { + int* i32; + float* f; + char* raw; + const char* raw_const; + uint8_t* uint8; +} TfLitePtrUnion; + +// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped +// data (or data externally allocated). kTfLiteArenaRw is arena allocated +// data. kTfLiteDynamic is for tensors that are allocated during evaluation. +typedef enum { + kTfLiteMemNone = 0, + kTfLiteMmapRo, + kTfLiteArenaRw, + kTfLiteArenaRwPersistent, + kTfLiteDynamic, +} TfLiteAllocationType; + +// An tensor in the interpreter system which is a wrapper around a buffer of +// data including a dimensionality (or NULL if not currently defined). +typedef struct { + // The data type specification for data stored in `data`. This affects + // what member of `data` union should be used. + TfLiteType type; + // A union of data pointers. The appropriate type should be used for a typed + // tensor based on `type`. + TfLitePtrUnion data; + // A pointer to a structure representing the dimensionality interpretation + // that the buffer should have. NOTE: the product of elements of `dims` + // and the element datatype size should be equal to `bytes` below. + TfLiteIntArray* dims; + // Quantization information. + TfLiteQuantizationParams params; + // How memory is mapped + // kTfLiteMmapRo: Memory mapped read only. + // i.e. weights + // kTfLiteArenaRw: Arena allocated read write memory + // (i.e. temporaries, outputs). + TfLiteAllocationType allocation_type; + // The number of bytes required to store the data of this Tensor. I.e. + // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if + // type is kTfLiteFloat32 and dims = {3, 2} then + // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24. + size_t bytes; + + // An opaque pointer to a tflite::MMapAllocation + const void* allocation; + + // Null-terminated name of this tensor. + const char* name; +} TfLiteTensor; + +// Free memory of tensor `t`; +void TfLiteTensorFree(TfLiteTensor* t); + +// Set all of a tensor's fields (and free any previously allocated data). +void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, + TfLiteQuantizationParams quantization, char* buffer, + size_t size, TfLiteAllocationType allocation_type, + const void* allocation, TfLiteTensor* tensor); + +// Resize the allocated data of a (dynamic) tensor. +void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); + +typedef struct TfLiteContext { + // Number of tensors in the context. + int tensors_size; + // An tensor of tensors in the interpreter context (of length `tensors_size`) + TfLiteTensor* tensors; + + // opaque full context ptr (an opaque c++ data structure) + void* impl_; + + // Request memory pointer be resized. Updates dimensions on the tensor. + // NOTE: ResizeTensor takes ownership of newSize. + TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Request that a error be reported with format string msg. + void (*ReportError)(struct TfLiteContext*, const char* msg, ...); + + // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If + // non-null, the value pointed to by `first_new_tensor_index` will be set to + // the index of the first new tensor. + TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, + int* first_new_tensor_index); + + // TODO(ahentz): we should create a more general mechanism for this sort of + // library-global objects. + void* gemm_context; +} TfLiteContext; + +// A structure representing an instance of a node. +// This structure only exhibits the inputs, outputs and user defined data, not +// other features like the type. +typedef struct { + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. + void* builtin_data; +} TfLiteNode; + +typedef struct { + // Initializes the op from serialized data. + // If a built-in op: + // `buffer` is the op's params data (TfLiteLSTMParams*). + // `length` is zero. + // If custom op: + // `buffer` is the op's `custom_options`. + // `length` is the size of the buffer. + // + // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer + // or an instance of a struct). + // + // The returned pointer will be stored with the node in the `user_data` field, + // accessible within prepare and invoke functions below. + // NOTE: if the data is already in the desired format, simply implement this + // function to return `nullptr` and implement the free function to be a no-op. + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + + // The pointer `buffer` is the data previously returned by an init invocation. + void (*free)(TfLiteContext* context, void* buffer); + + // prepare is called when the inputs this node depends on have been resized. + // context->ResizeTensor() can be called to request output tensors to be + // resized. + // + // Returns kTfLiteOk on success. + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + + // Execute the node (should read node->inputs and output to node->outputs). + // Returns kTfLiteOk on success. + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); + + // Builtin codes. If this kernel refers to a builtin this is the code + // of the builtin. This is so we can do marshaling to other frameworks like + // NN API. Note, it is the responsibility of the registration binder to + // set this properly. + int32_t builtin_code; +} TfLiteRegistration; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..20d6f69a25e9f0bb4323cf5d067b8ebd37bb3c23 --- /dev/null +++ b/tensorflow/contrib/lite/context_test.cc @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/context.h" +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { + +// NOTE: this tests only the TfLiteIntArray part of context. +// most of context.h is provided in the context of using it with interpreter.h +// and interpreter.cc, so interpreter_test.cc tests context structures more +// thoroughly. + +TEST(IntArray, TestIntArrayCreate) { + TfLiteIntArray* a = TfLiteIntArrayCreate(0); + TfLiteIntArray* b = TfLiteIntArrayCreate(3); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayCopy) { + TfLiteIntArray* a = TfLiteIntArrayCreate(2); + a->data[0] = 22; + a->data[1] = 24; + TfLiteIntArray* b = TfLiteIntArrayCopy(a); + ASSERT_NE(a, b); + ASSERT_EQ(a->size, b->size); + ASSERT_EQ(a->data[0], b->data[0]); + ASSERT_EQ(a->data[1], b->data[1]); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); +} + +TEST(IntArray, TestIntArrayEqual) { + TfLiteIntArray* a = TfLiteIntArrayCreate(1); + a->data[0] = 1; + TfLiteIntArray* b = TfLiteIntArrayCreate(2); + b->data[0] = 5; + b->data[1] = 6; + TfLiteIntArray* c = TfLiteIntArrayCreate(2); + c->data[0] = 5; + c->data[1] = 6; + TfLiteIntArray* d = TfLiteIntArrayCreate(2); + d->data[0] = 6; + d->data[1] = 6; + ASSERT_FALSE(TfLiteIntArrayEqual(a, b)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, c)); + ASSERT_TRUE(TfLiteIntArrayEqual(b, b)); + ASSERT_FALSE(TfLiteIntArrayEqual(c, d)); + TfLiteIntArrayFree(a); + TfLiteIntArrayFree(b); + TfLiteIntArrayFree(c); + TfLiteIntArrayFree(d); +} + +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh new file mode 100755 index 0000000000000000000000000000000000000000..7fce1ba3461066e6dada95246781440258d844c1 --- /dev/null +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -0,0 +1,106 @@ +#!/bin/bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +DOWNLOADS_DIR=tensorflow/contrib/lite/downloads +BZL_FILE_PATH=tensorflow/workspace.bzl + +# Ensure it is being run from repo root +if [ ! -f $BZL_FILE_PATH ]; then + echo "Could not find ${BZL_FILE_PATH}": + echo "Likely you are not running this from the root directory of the repository."; + exit 1; +fi + +EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" +ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" +NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" +FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" +FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" +MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_ios_lite_float_2017_11_08.zip" +QUANTIZED_MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" + +# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, +# so work around it by patching the source. +replace_by_sed() { + local regex="${1}" + shift + # Detect the version of sed by the return value of "--version" flag. GNU-sed + # supports "--version" while BSD-sed doesn't. + if ! sed --version >/dev/null 2>&1; then + # BSD-sed. + sed -i '' -e "${regex}" "$@" + else + # GNU-sed. + sed -i -e "${regex}" "$@" + fi +} + +download_and_extract() { + local usage="Usage: download_and_extract URL DIR" + local url="${1:?${usage}}" + local dir="${2:?${usage}}" + echo "downloading ${url}" >&2 + mkdir -p "${dir}" + if [[ "${url}" == *gz ]]; then + curl -Ls "${url}" | tar -C "${dir}" --strip-components=1 -xz + elif [[ "${url}" == *zip ]]; then + tempdir=$(mktemp -d) + tempdir2=$(mktemp -d) + + curl -L ${url} > ${tempdir}/zipped.zip + unzip ${tempdir}/zipped.zip -d ${tempdir2} + + # If the zip file contains nested directories, extract the files from the + # inner directory. + if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then + # unzip has no strip components, so unzip to a temp dir, and move the + # files we want from the tempdir to destination. + cp -R ${tempdir2}/*/* ${dir}/ + else + cp -R ${tempdir2}/* ${dir}/ + fi + rm -rf ${tempdir2} ${tempdir} + fi + + # Delete any potential BUILD files, which would interfere with Bazel builds. + find "${dir}" -type f -name '*BUILD' -delete +} + +download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen" +download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp" +download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest" +download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" +download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" +download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" +download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" +download_and_extract "${MODELS_URL}" "${DOWNLOADS_DIR}/models" +download_and_extract "${QUANTIZED_MODELS_URL}" "${DOWNLOADS_DIR}/quantized_models" + +replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" +replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" +replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \ + "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" + +cp ${DOWNLOADS_DIR}/models/models/* tensorflow/contrib/lite/examples/ios/simple/data/ +cp ${DOWNLOADS_DIR}/quantized_models/* tensorflow/contrib/lite/examples/ios/camera/data/ + +echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ba5384a94dbf9de03fb2e4e2f63074525eafa2d --- /dev/null +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/error_reporter.h" +#include +#include + +namespace tflite { + +ErrorReporter::~ErrorReporter() {} + +int ErrorReporter::Report(const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +// TODO(aselle): Make the name of ReportError on context the same, so +// we can use the ensure functions w/o a context and w/ a reporter. +int ErrorReporter::ReportError(void*, const char* format, ...) { + va_list args; + va_start(args, format); + int code = Report(format, args); + va_end(args); + return code; +} + +int StderrReporter::Report(const char* format, va_list args) { + return vfprintf(stderr, format, args); +} + +ErrorReporter* DefaultErrorReporter() { + static StderrReporter* error_reporter = new StderrReporter; + return error_reporter; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h new file mode 100644 index 0000000000000000000000000000000000000000..637d456ce7a754c7da34e551869e49b4efd18e3b --- /dev/null +++ b/tensorflow/contrib/lite/error_reporter.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// A functor that reports error to supporting system. Invoked similar to +// printf. +// +// Usage: +// ErrorReporter foo; +// foo.Report("test %d\n", 5); +// or +// va_list args; +// foo.Report("test %d\n", args); // where args is va_list +// +// Sublclass ErrorReporter to provide another reporting destination. +// For example, if you have a GUI program, you might redirect to a buffer +// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter(); + virtual int Report(const char* format, va_list args) = 0; + int Report(const char* format, ...); + int ReportError(void*, const char* format, ...); +}; + +// An error reporter that simplify writes the message to stderr. +struct StderrReporter : public ErrorReporter { + int Report(const char* format, va_list args) override; +}; + +// Return the default error reporter (output to stderr). +ErrorReporter* DefaultErrorReporter(); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/examples/ios/camera/.gitignore b/tensorflow/contrib/lite/examples/ios/camera/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9e8962f4c63562dd95896833f563abfbfb578ccc --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/.gitignore @@ -0,0 +1,2 @@ +/data/*.txt +/data/*.tflite diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..55891c3ee18318037fd14fe4160c6f012aeaae66 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface CameraExampleAppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow* window; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m new file mode 100644 index 0000000000000000000000000000000000000000..128266d53f560f3009f6435939ab48ae1c117a3a --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleAppDelegate.m @@ -0,0 +1,44 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "CameraExampleAppDelegate.h" + +@implementation CameraExampleAppDelegate + +@synthesize window = _window; + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { + [[UIApplication sharedApplication] setIdleTimerDisabled:NO]; +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { + [[UIApplication sharedApplication] setIdleTimerDisabled:YES]; +} + +- (void)applicationWillTerminate:(UIApplication *)application { +} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..fb5800e86d365b56f1b52147c3f9cc8d7211f8c3 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.h @@ -0,0 +1,48 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import +#import + +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" + +@interface CameraExampleViewController + : UIViewController { + IBOutlet UIView* previewView; + AVCaptureVideoPreviewLayer* previewLayer; + AVCaptureVideoDataOutput* videoDataOutput; + dispatch_queue_t videoDataOutputQueue; + UIView* flashView; + BOOL isUsingFrontFacingCamera; + NSMutableDictionary* oldPredictionValues; + NSMutableArray* labelLayers; + AVCaptureSession* session; + + std::vector labels; + std::unique_ptr model; + tflite::ops::builtin::BuiltinOpResolver resolver; + std::unique_ptr interpreter; + + double total_latency; + int total_count; +} +@property(strong, nonatomic) CATextLayer* predictionTextLayer; + +- (IBAction)takePicture:(id)sender; +- (IBAction)switchCameras:(id)sender; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..10f31bb6f17242c9f7f70f0648ec643f99c5ac86 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -0,0 +1,510 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "CameraExampleViewController.h" +#import +#import +#import +#import + +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +#define LOG(x) std::cerr + +// If you have your own model, modify this to the file name, and make sure +// you've added the file to your app resources too. +static NSString* model_file_name = @"mobilenet_quant_v1_224"; +static NSString* model_file_type = @"tflite"; + +// If you have your own model, point this to the labels file. +static NSString* labels_file_name = @"labels"; +static NSString* labels_file_type = @"txt"; + +// These dimensions need to match those the model was trained with. +static const int wanted_input_width = 224; +static const int wanted_input_height = 224; +static const int wanted_input_channels = 3; + +static NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +static void LoadLabels(NSString* file_name, NSString* file_type, + std::vector* label_strings) { + NSString* labels_path = FilePathForResourceName(file_name, file_type); + if (!labels_path) { + LOG(ERROR) << "Failed to find model proto at" << [file_name UTF8String] + << [file_type UTF8String]; + } + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while (t) { + std::getline(t, line); + label_strings->push_back(line); + } + t.close(); +} + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +static void GetTopN(const uint8_t* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector>* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, std::vector>, + std::greater>> + top_result_pq; + + const long count = prediction_size; + for (int i = 0; i < count; ++i) { + const float value = prediction[i] / 255.0; + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +@interface CameraExampleViewController (InternalMethods) +- (void)setupAVCapture; +- (void)teardownAVCapture; +@end + +@implementation CameraExampleViewController + +- (void)setupAVCapture { + NSError* error = nil; + + session = [AVCaptureSession new]; + if ([[UIDevice currentDevice] userInterfaceIdiom] == UIUserInterfaceIdiomPhone) + [session setSessionPreset:AVCaptureSessionPreset640x480]; + else + [session setSessionPreset:AVCaptureSessionPresetPhoto]; + + AVCaptureDevice* device = [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; + AVCaptureDeviceInput* deviceInput = + [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; + + if (error != nil) { + NSLog(@"Failed to initialize AVCaptureDeviceInput. Note: This app doesn't work with simulator"); + assert(NO); + } + + if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; + + videoDataOutput = [AVCaptureVideoDataOutput new]; + + NSDictionary* rgbOutputSettings = + [NSDictionary dictionaryWithObject:[NSNumber numberWithInt:kCMPixelFormat_32BGRA] + forKey:(id)kCVPixelBufferPixelFormatTypeKey]; + [videoDataOutput setVideoSettings:rgbOutputSettings]; + [videoDataOutput setAlwaysDiscardsLateVideoFrames:YES]; + videoDataOutputQueue = dispatch_queue_create("VideoDataOutputQueue", DISPATCH_QUEUE_SERIAL); + [videoDataOutput setSampleBufferDelegate:self queue:videoDataOutputQueue]; + + if ([session canAddOutput:videoDataOutput]) [session addOutput:videoDataOutput]; + [[videoDataOutput connectionWithMediaType:AVMediaTypeVideo] setEnabled:YES]; + + previewLayer = [[AVCaptureVideoPreviewLayer alloc] initWithSession:session]; + [previewLayer setBackgroundColor:[[UIColor blackColor] CGColor]]; + [previewLayer setVideoGravity:AVLayerVideoGravityResizeAspect]; + CALayer* rootLayer = [previewView layer]; + [rootLayer setMasksToBounds:YES]; + [previewLayer setFrame:[rootLayer bounds]]; + [rootLayer addSublayer:previewLayer]; + [session startRunning]; + + if (error) { + NSString* title = [NSString stringWithFormat:@"Failed with error %d", (int)[error code]]; + UIAlertController* alertController = + [UIAlertController alertControllerWithTitle:title + message:[error localizedDescription] + preferredStyle:UIAlertControllerStyleAlert]; + UIAlertAction* dismiss = + [UIAlertAction actionWithTitle:@"Dismiss" style:UIAlertActionStyleDefault handler:nil]; + [alertController addAction:dismiss]; + [self presentViewController:alertController animated:YES completion:nil]; + [self teardownAVCapture]; + } +} + +- (void)teardownAVCapture { + [previewLayer removeFromSuperlayer]; +} + +- (AVCaptureVideoOrientation)avOrientationForDeviceOrientation: + (UIDeviceOrientation)deviceOrientation { + AVCaptureVideoOrientation result = (AVCaptureVideoOrientation)(deviceOrientation); + if (deviceOrientation == UIDeviceOrientationLandscapeLeft) + result = AVCaptureVideoOrientationLandscapeRight; + else if (deviceOrientation == UIDeviceOrientationLandscapeRight) + result = AVCaptureVideoOrientationLandscapeLeft; + return result; +} + +- (IBAction)takePicture:(id)sender { + if ([session isRunning]) { + [session stopRunning]; + [sender setTitle:@"Continue" forState:UIControlStateNormal]; + + flashView = [[UIView alloc] initWithFrame:[previewView frame]]; + [flashView setBackgroundColor:[UIColor whiteColor]]; + [flashView setAlpha:0.f]; + [[[self view] window] addSubview:flashView]; + + [UIView animateWithDuration:.2f + animations:^{ + [flashView setAlpha:1.f]; + } + completion:^(BOOL finished) { + [UIView animateWithDuration:.2f + animations:^{ + [flashView setAlpha:0.f]; + } + completion:^(BOOL finished) { + [flashView removeFromSuperview]; + flashView = nil; + }]; + }]; + + } else { + [session startRunning]; + [sender setTitle:@"Freeze Frame" forState:UIControlStateNormal]; + } +} + +- (void)captureOutput:(AVCaptureOutput*)captureOutput + didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer + fromConnection:(AVCaptureConnection*)connection { + CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer); + CFRetain(pixelBuffer); + [self runModelOnFrame:pixelBuffer]; + CFRelease(pixelBuffer); +} + +- (void)runModelOnFrame:(CVPixelBufferRef)pixelBuffer { + assert(pixelBuffer != NULL); + + OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); + int doReverseChannels; + if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { + doReverseChannels = 1; + } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { + doReverseChannels = 0; + } else { + assert(false); // Unknown source format + } + + const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); + const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); + const int fullHeight = (int)CVPixelBufferGetHeight(pixelBuffer); + + CVPixelBufferLockFlags unlockFlags = kNilOptions; + CVPixelBufferLockBaseAddress(pixelBuffer, unlockFlags); + + unsigned char* sourceBaseAddr = (unsigned char*)(CVPixelBufferGetBaseAddress(pixelBuffer)); + int image_height; + unsigned char* sourceStartAddr; + if (fullHeight <= image_width) { + image_height = fullHeight; + sourceStartAddr = sourceBaseAddr; + } else { + image_height = image_width; + const int marginY = ((fullHeight - image_width) / 2); + sourceStartAddr = (sourceBaseAddr + (marginY * sourceRowBytes)); + } + const int image_channels = 4; + assert(image_channels >= wanted_input_channels); + uint8_t* in = sourceStartAddr; + + int input = interpreter->inputs()[0]; + + uint8_t* out = interpreter->typed_tensor(input); + for (int y = 0; y < wanted_input_height; ++y) { + uint8_t* out_row = out + (y * wanted_input_width * wanted_input_channels); + for (int x = 0; x < wanted_input_width; ++x) { + const int in_x = (y * image_width) / wanted_input_width; + const int in_y = (x * image_height) / wanted_input_height; + uint8_t* in_pixel = in + (in_y * image_width * image_channels) + (in_x * image_channels); + uint8_t* out_pixel = out_row + (x * wanted_input_channels); + for (int c = 0; c < wanted_input_channels; ++c) { + out_pixel[c] = in_pixel[c]; + } + } + } + + double startTimestamp = [[NSDate new] timeIntervalSince1970]; + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke!"; + } + double endTimestamp = [[NSDate new] timeIntervalSince1970]; + total_latency += (endTimestamp - startTimestamp); + total_count += 1; + NSLog(@"Time: %.4lf, avg: %.4lf, count: %d", endTimestamp - startTimestamp, + total_latency / total_count, total_count); + + const int output_size = 1000; + const int kNumResults = 5; + const float kThreshold = 0.1f; + + std::vector> top_results; + + uint8_t* output = interpreter->typed_output_tensor(0); + GetTopN(output, output_size, kNumResults, kThreshold, &top_results); + + NSMutableDictionary* newValues = [NSMutableDictionary dictionary]; + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + NSString* labelObject = [NSString stringWithUTF8String:labels[index].c_str()]; + NSNumber* valueObject = [NSNumber numberWithFloat:confidence]; + [newValues setObject:valueObject forKey:labelObject]; + } + dispatch_async(dispatch_get_main_queue(), ^(void) { + [self setPredictionValues:newValues]; + }); + + CVPixelBufferUnlockBaseAddress(pixelBuffer, unlockFlags); + + CVPixelBufferUnlockBaseAddress(pixelBuffer, 0); +} + +- (void)dealloc { + [self teardownAVCapture]; +} + +- (void)didReceiveMemoryWarning { + [super didReceiveMemoryWarning]; +} + +- (void)viewDidLoad { + [super viewDidLoad]; + labelLayers = [[NSMutableArray alloc] init]; + oldPredictionValues = [[NSMutableDictionary alloc] init]; + + NSString* graph_path = FilePathForResourceName(model_file_name, @"tflite"); + model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]); + if (!model) { + LOG(FATAL) << "Failed to mmap model " << graph_path; + } + LOG(INFO) << "Loaded model " << graph_path; + model->error_reporter(); + LOG(INFO) << "resolved reporter"; + + tflite::ops::builtin::BuiltinOpResolver resolver; + LoadLabels(labels_file_name, labels_file_type, &labels); + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter"; + } + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + [self setupAVCapture]; +} + +- (void)viewDidUnload { + [super viewDidUnload]; +} + +- (void)viewWillAppear:(BOOL)animated { + [super viewWillAppear:animated]; +} + +- (void)viewDidAppear:(BOOL)animated { + [super viewDidAppear:animated]; +} + +- (void)viewWillDisappear:(BOOL)animated { + [super viewWillDisappear:animated]; +} + +- (void)viewDidDisappear:(BOOL)animated { + [super viewDidDisappear:animated]; +} + +- (BOOL)shouldAutorotateToInterfaceOrientation:(UIInterfaceOrientation)interfaceOrientation { + return (interfaceOrientation == UIInterfaceOrientationPortrait); +} + +- (BOOL)prefersStatusBarHidden { + return YES; +} + +- (void)setPredictionValues:(NSDictionary*)newValues { + const float decayValue = 0.75f; + const float updateValue = 0.25f; + const float minimumThreshold = 0.01f; + + NSMutableDictionary* decayedPredictionValues = [[NSMutableDictionary alloc] init]; + for (NSString* label in oldPredictionValues) { + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + const float decayedPredictionValue = (oldPredictionValue * decayValue); + if (decayedPredictionValue > minimumThreshold) { + NSNumber* decayedPredictionValueObject = [NSNumber numberWithFloat:decayedPredictionValue]; + [decayedPredictionValues setObject:decayedPredictionValueObject forKey:label]; + } + } + oldPredictionValues = decayedPredictionValues; + + for (NSString* label in newValues) { + NSNumber* newPredictionValueObject = [newValues objectForKey:label]; + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + if (!oldPredictionValueObject) { + oldPredictionValueObject = [NSNumber numberWithFloat:0.0f]; + } + const float newPredictionValue = [newPredictionValueObject floatValue]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + const float updatedPredictionValue = (oldPredictionValue + (newPredictionValue * updateValue)); + NSNumber* updatedPredictionValueObject = [NSNumber numberWithFloat:updatedPredictionValue]; + [oldPredictionValues setObject:updatedPredictionValueObject forKey:label]; + } + NSArray* candidateLabels = [NSMutableArray array]; + for (NSString* label in oldPredictionValues) { + NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label]; + const float oldPredictionValue = [oldPredictionValueObject floatValue]; + if (oldPredictionValue > 0.05f) { + NSDictionary* entry = @{@"label" : label, @"value" : oldPredictionValueObject}; + candidateLabels = [candidateLabels arrayByAddingObject:entry]; + } + } + NSSortDescriptor* sort = [NSSortDescriptor sortDescriptorWithKey:@"value" ascending:NO]; + NSArray* sortedLabels = + [candidateLabels sortedArrayUsingDescriptors:[NSArray arrayWithObject:sort]]; + + const float leftMargin = 10.0f; + const float topMargin = 10.0f; + + const float valueWidth = 48.0f; + const float valueHeight = 18.0f; + + const float labelWidth = 246.0f; + const float labelHeight = 18.0f; + + const float labelMarginX = 5.0f; + const float labelMarginY = 5.0f; + + [self removeAllLabelLayers]; + + int labelCount = 0; + for (NSDictionary* entry in sortedLabels) { + NSString* label = [entry objectForKey:@"label"]; + NSNumber* valueObject = [entry objectForKey:@"value"]; + const float value = [valueObject floatValue]; + const float originY = topMargin + ((labelHeight + labelMarginY) * labelCount); + const int valuePercentage = (int)roundf(value * 100.0f); + + const float valueOriginX = leftMargin; + NSString* valueText = [NSString stringWithFormat:@"%d%%", valuePercentage]; + + [self addLabelLayerWithText:valueText + originX:valueOriginX + originY:originY + width:valueWidth + height:valueHeight + alignment:kCAAlignmentRight]; + + const float labelOriginX = (leftMargin + valueWidth + labelMarginX); + + [self addLabelLayerWithText:[label capitalizedString] + originX:labelOriginX + originY:originY + width:labelWidth + height:labelHeight + alignment:kCAAlignmentLeft]; + + labelCount += 1; + if (labelCount > 4) { + break; + } + } +} + +- (void)removeAllLabelLayers { + for (CATextLayer* layer in labelLayers) { + [layer removeFromSuperlayer]; + } + [labelLayers removeAllObjects]; +} + +- (void)addLabelLayerWithText:(NSString*)text + originX:(float)originX + originY:(float)originY + width:(float)width + height:(float)height + alignment:(NSString*)alignment { + CFTypeRef font = (CFTypeRef) @"Menlo-Regular"; + const float fontSize = 12.0; + const float marginSizeX = 5.0f; + const float marginSizeY = 2.0f; + + const CGRect backgroundBounds = CGRectMake(originX, originY, width, height); + const CGRect textBounds = CGRectMake((originX + marginSizeX), (originY + marginSizeY), + (width - (marginSizeX * 2)), (height - (marginSizeY * 2))); + + CATextLayer* background = [CATextLayer layer]; + [background setBackgroundColor:[UIColor blackColor].CGColor]; + [background setOpacity:0.5f]; + [background setFrame:backgroundBounds]; + background.cornerRadius = 5.0f; + + [[self.view layer] addSublayer:background]; + [labelLayers addObject:background]; + + CATextLayer* layer = [CATextLayer layer]; + [layer setForegroundColor:[UIColor whiteColor].CGColor]; + [layer setFrame:textBounds]; + [layer setAlignmentMode:alignment]; + [layer setWrapped:YES]; + [layer setFont:font]; + [layer setFontSize:fontSize]; + layer.contentsScale = [[UIScreen mainScreen] scale]; + [layer setString:text]; + + [[self.view layer] addSublayer:layer]; + [labelLayers addObject:layer]; +} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/camera/Info.plist b/tensorflow/contrib/lite/examples/ios/camera/Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..f3d96bab162a707df4df8655354af5a54d1e985e --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/Info.plist @@ -0,0 +1,44 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tflite_camera_example + CFBundleExecutable + ${EXECUTABLE_NAME} + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ${PRODUCT_NAME} + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? + CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + NSCameraUsageDescription + Capture images to detect object + UIMainStoryboardFile + MainStoryboard_iPhone + UIRequiresFullScreen + + UIStatusBarHidden + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + + + diff --git a/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard b/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard new file mode 100644 index 0000000000000000000000000000000000000000..0f10a22e415bd2519e90dd6bfac8b2ad6230caab --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/MainStoryboard_iPhone.storyboard @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..4ae6fb6b94e4489f63506b05a2f348b7daafd3b7 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! + +target 'tflite_camera_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/contrib/lite/examples/ios/camera/data/.gitignore b/tensorflow/contrib/lite/examples/ios/camera/data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tensorflow/contrib/lite/examples/ios/camera/main.mm b/tensorflow/contrib/lite/examples/ios/camera/main.mm new file mode 100644 index 0000000000000000000000000000000000000000..1a9e542f7c9a5b09be6463437c3a8e4a5afeda6d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/main.mm @@ -0,0 +1,28 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +#import "CameraExampleAppDelegate.h" + +int main(int argc, char* argv[]) { + int retVal = 0; + + @autoreleasepool { + retVal = + UIApplicationMain(argc, argv, nil, NSStringFromClass([CameraExampleAppDelegate class])); + } + return retVal; +} diff --git a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..c98183276bd60d2a0ad023ba26aad12572a02786 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj @@ -0,0 +1,419 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */; }; + 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */; }; + 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */; }; + 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */; }; + 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */; }; + 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */; }; + 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; }; + 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */; }; + AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = AC1F82641FBA3CBD0052BA77 /* labels.txt */; }; + AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */; }; + ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tflite_camera_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; path = MainStoryboard_iPhone.storyboard; sourceTree = ""; }; + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = System/Library/Frameworks/CoreMedia.framework; sourceTree = SDKROOT; }; + 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = System/Library/Frameworks/AVFoundation.framework; sourceTree = SDKROOT; }; + 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = ""; }; + 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = ""; }; + 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = ""; }; + 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = ""; }; + 1CDB2D4D1ED3AA35007929E9 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; + 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tflite_camera_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.debug.xcconfig"; sourceTree = ""; }; + 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.release.xcconfig"; sourceTree = ""; }; + AC1F82641FBA3CBD0052BA77 /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; + AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; + ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_quant_v1_224.tflite; sourceTree = ""; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 1C564C0A1ED3A92E00087306 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */, + 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */, + 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */, + 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 24D7686C331131624F4454A0 /* Frameworks */ = { + isa = PBXGroup; + children = ( + AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */, + 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */, + 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */, + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, + 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 3E9FC355632FB928EA23BEED /* Pods */ = { + isa = PBXGroup; + children = ( + 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */, + 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */, + ); + name = Pods; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */, + 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */, + 1CDB2D4D1ED3AA35007929E9 /* Info.plist */, + 1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */, + 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */, + 1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */, + 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */, + 59A3CFF31CF4E68100C4259F /* data */, + 5911579C1CF4011C00C31E3A /* Products */, + 3E9FC355632FB928EA23BEED /* Pods */, + 24D7686C331131624F4454A0 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */, + AC1F82641FBA3CBD0052BA77 /* labels.txt */, + ); + path = data; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 1C564C0C1ED3A92E00087306 /* tflite_camera_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tflite_camera_example" */; + buildPhases = ( + 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */, + 1C564C091ED3A92E00087306 /* Sources */, + 1C564C0A1ED3A92E00087306 /* Frameworks */, + 1C564C0B1ED3A92E00087306 /* Resources */, + 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */, + 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tflite_camera_example; + productName = tflite_camera_example; + productReference = 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastSwiftUpdateCheck = 0830; + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 1C564C0C1ED3A92E00087306 = { + CreatedOnToolsVersion = 8.3.2; + DevelopmentTeam = EQHXZ8M8AV; + ProvisioningStyle = Automatic; + }; + }; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tflite_camera_example" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 1C564C0C1ED3A92E00087306 /* tflite_camera_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 1C564C0B1ED3A92E00087306 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */, + 1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */, + 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */, + AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXShellScriptBuildPhase section */ + 00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + 5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Copy Pods Resources"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example-resources.sh\"\n"; + showEnvVarsInLog = 0; + }; + 66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-tflite_camera_example-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 1C564C091ED3A92E00087306 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */, + 1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */, + 1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 1C564C361ED3A92E00087306 /* Debug */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + INFOPLIST_FILE = Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 10.3; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 3.0; + }; + name = Debug; + }; + 1C564C371ED3A92E00087306 /* Release */ = { + isa = XCBuildConfiguration; + baseConfigurationReference = 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + INFOPLIST_FILE = Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 10.3; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SWIFT_OPTIMIZATION_LEVEL = "-Owholemodule"; + SWIFT_VERSION = 3.0; + }; + name = Release; + }; + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tflite_camera_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 1C564C361ED3A92E00087306 /* Debug */, + 1C564C371ED3A92E00087306 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tflite_camera_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h new file mode 100644 index 0000000000000000000000000000000000000000..94046d9728258901091f018fd0d081651145f400 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h @@ -0,0 +1,21 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface AppDelegate : UIResponder + +@property(strong, nonatomic) UIWindow *window; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm new file mode 100644 index 0000000000000000000000000000000000000000..d1215fa0bffd978b4aaadbd8bc13b07723703c9a --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm @@ -0,0 +1,48 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "AppDelegate.h" + +#import "RunModelViewController.h" + +@implementation AppDelegate + +- (BOOL)application:(UIApplication *)application + didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { + + UITabBarController *bar = [[UITabBarController alloc] init]; + [bar setViewControllers:@[ [[RunModelViewController alloc] init] ]]; + bar.selectedIndex = 0; + self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; + self.window.rootViewController = bar; + [self.window makeKeyAndVisible]; + return YES; +} + +- (void)applicationWillResignActive:(UIApplication *)application { +} + +- (void)applicationDidEnterBackground:(UIApplication *)application { +} + +- (void)applicationWillEnterForeground:(UIApplication *)application { +} + +- (void)applicationDidBecomeActive:(UIApplication *)application { +} + +- (void)applicationWillTerminate:(UIApplication *)application { +} + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile new file mode 100644 index 0000000000000000000000000000000000000000..1740ad64573a84fae6de0fcf284eb06afec67e25 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -0,0 +1,5 @@ +platform :ios, '8.0' +inhibit_all_warnings! + +target 'tf_simple_example' + pod 'TensorFlow-experimental' diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist new file mode 100644 index 0000000000000000000000000000000000000000..1a3eaa8a2c18d1cd24dfd475d396b00ec4d86c9d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist @@ -0,0 +1,47 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleDisplayName + tflite-simple-example + CFBundleExecutable + tf_simple_example + CFBundleIdentifier + $(PRODUCT_BUNDLE_IDENTIFIER) + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + ios-app + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleSignature + ???? + CFBundleVersion + 1.0 + LSRequiresIPhoneOS + + UILaunchStoryboardName + RunModelViewController + UIRequiredDeviceCapabilities + + armv7 + + UISupportedInterfaceOrientations + + UIInterfaceOrientationPortrait + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + UISupportedInterfaceOrientations~ipad + + UIInterfaceOrientationPortrait + UIInterfaceOrientationPortraitUpsideDown + UIInterfaceOrientationLandscapeLeft + UIInterfaceOrientationLandscapeRight + + + diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h new file mode 100644 index 0000000000000000000000000000000000000000..a4b358b4eb7f6ba109638405091b798d30bd1768 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h @@ -0,0 +1,24 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +@interface RunModelViewController : UIViewController + +- (IBAction)getUrl:(id)sender; + +@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property(weak, nonatomic) IBOutlet UITextField *urlTextField; + +@end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm new file mode 100644 index 0000000000000000000000000000000000000000..0dafb1f61e19f46bb3b17f07c55e09f5813ed560 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -0,0 +1,221 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import "RunModelViewController.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" + +#include "ios_image_load.h" + +#define LOG(x) std::cerr +#define CHECK(x) \ + if (!(x)) { \ + LOG(ERROR) << #x << "failed"; \ + exit(1); \ + } + +NSString* RunInferenceOnImage(); + +@interface RunModelViewController () +@end + +@implementation RunModelViewController { +} + +- (IBAction)getUrl:(id)sender { + NSString* inference_result = RunInferenceOnImage(); + self.urlContentTextView.text = inference_result; +} + +@end + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +static void GetTopN(const float* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector >* top_results) { + // Will contain top N results in ascending order. + std::priority_queue, std::vector >, + std::greater > > + top_result_pq; + + const long count = prediction_size; + for (int i = 0; i < count; ++i) { + const float value = prediction[i]; + + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +NSString* FilePathForResourceName(NSString* name, NSString* extension) { + NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; + if (file_path == NULL) { + LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String] + << "' in bundle."; + } + return file_path; +} + +NSString* RunInferenceOnImage() { + std::string graph; + const int num_threads = 1; + std::string input_layer_type = "float"; + std::vector sizes = {1, 224, 224, 3}; + + NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite"); + + std::unique_ptr model( + tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); + if (!model) { + LOG(FATAL) << "Failed to mmap model " << graph; + } + LOG(INFO) << "Loaded model " << graph; + model->error_reporter(); + LOG(INFO) << "resolved reporter"; + +#ifdef TFLITE_CUSTOM_OPS_HEADER + tflite::MutableOpResolver resolver; + RegisterSelectedOps(&resolver); +#else + tflite::ops::builtin::BuiltinOpResolver resolver; +#endif + + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter"; + } + + if (num_threads != -1) { + interpreter->SetNumThreads(num_threads); + } + + int input = interpreter->inputs()[0]; + + if (input_layer_type != "string") { + interpreter->ResizeInputTensor(input, sizes); + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + // Read the label list + NSString* labels_path = FilePathForResourceName(@"labels", @"txt"); + std::vector label_strings; + std::ifstream t; + t.open([labels_path UTF8String]); + std::string line; + while (t) { + std::getline(t, line); + label_strings.push_back(line); + } + t.close(); + + // Read the Grace Hopper image. + NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg"); + int image_width; + int image_height; + int image_channels; + std::vector image_data = + LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); + const int wanted_width = 224; + const int wanted_height = 224; + const int wanted_channels = 3; + const float input_mean = 127.5f; + const float input_std = 127.5f; + assert(image_channels >= wanted_channels); + uint8_t* in = image_data.data(); + float* out = interpreter->typed_tensor(input); + for (int y = 0; y < wanted_height; ++y) { + const int in_y = (y * image_height) / wanted_height; + uint8_t* in_row = in + (in_y * image_width * image_channels); + float* out_row = out + (y * wanted_width * wanted_channels); + for (int x = 0; x < wanted_width; ++x) { + const int in_x = (x * image_width) / wanted_width; + uint8_t* in_pixel = in_row + (in_x * image_channels); + float* out_pixel = out_row + (x * wanted_channels); + for (int c = 0; c < wanted_channels; ++c) { + out_pixel[c] = (in_pixel[c] - input_mean) / input_std; + } + } + } + + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke!"; + } + + float* output = interpreter->typed_output_tensor(0); + const int output_size = 1000; + const int kNumResults = 5; + const float kThreshold = 0.1f; + std::vector > top_results; + GetTopN(output, output_size, kNumResults, kThreshold, &top_results); + + std::stringstream ss; + ss.precision(3); + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + + ss << index << " " << confidence << " "; + + // Write out the result as a string + if (index < label_strings.size()) { + // just for safety: theoretically, the output is under 1000 unless there + // is some numerical issues leading to a wrong prediction. + ss << label_strings[index]; + } else { + ss << "Prediction: " << index; + } + + ss << "\n"; + } + + LOG(INFO) << "Predictions: " << ss.str(); + + std::string predictions = ss.str(); + NSString* result = @""; + result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; + + return result; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib new file mode 100644 index 0000000000000000000000000000000000000000..93f334b9850c6f5f22455b3d14a075c17a7c9171 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.xib @@ -0,0 +1,46 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg b/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d2a427810f679db537236c5430873a81a62ef412 Binary files /dev/null and b/tensorflow/contrib/lite/examples/ios/simple/data/grace_hopper.jpg differ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h new file mode 100644 index 0000000000000000000000000000000000000000..98934ce41d349b33d4fc010a39a956e52f3d5721 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h @@ -0,0 +1,23 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ +#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ + +#include + +std::vector LoadImageFromFile(const char* file_name, int* out_width, + int* out_height, int* out_channels); + +#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm new file mode 100644 index 0000000000000000000000000000000000000000..cb0fe1a7650c572d3745066431f2759daa94ffc9 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm @@ -0,0 +1,82 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ios_image_load.h" + +#include +#include +#include +#include + +#import +#import + +std::vector LoadImageFromFile(const char* file_name, int* out_width, int* out_height, + int* out_channels) { + FILE* file_handle = fopen(file_name, "rb"); + fseek(file_handle, 0, SEEK_END); + const size_t bytes_in_file = ftell(file_handle); + fseek(file_handle, 0, SEEK_SET); + std::vector file_data(bytes_in_file); + fread(file_data.data(), 1, bytes_in_file, file_handle); + fclose(file_handle); + + CFDataRef file_data_ref = + CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull); + CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref); + + const char* suffix = strrchr(file_name, '.'); + if (!suffix || suffix == file_name) { + suffix = ""; + } + CGImageRef image; + if (strcasecmp(suffix, ".png") == 0) { + image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || (strcasecmp(suffix, ".jpeg") == 0)) { + image = + CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); + } else { + CFRelease(image_provider); + CFRelease(file_data_ref); + fprintf(stderr, "Unknown suffix for file '%s'\n", file_name); + *out_width = 0; + *out_height = 0; + *out_channels = 0; + return std::vector(); + } + + const int width = (int)CGImageGetWidth(image); + const int height = (int)CGImageGetHeight(image); + const int channels = 4; + CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB(); + const int bytes_per_row = (width * channels); + const int bytes_in_image = (bytes_per_row * height); + std::vector result(bytes_in_image); + const int bits_per_component = 8; + + CGContextRef context = + CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row, + color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGColorSpaceRelease(color_space); + CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); + CGContextRelease(context); + CFRelease(image); + CFRelease(image_provider); + CFRelease(file_data_ref); + + *out_width = width; + *out_height = height; + *out_channels = channels; + return result; +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/contrib/lite/examples/ios/simple/main.mm new file mode 100644 index 0000000000000000000000000000000000000000..05cb55ddd7a230593863e64b351f6aac31a1b4d7 --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/main.mm @@ -0,0 +1,22 @@ +// Copyright 2015 Google Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#import + +int main(int argc, char *argv[]) { + @autoreleasepool { + NSString *delegateClassName = @"AppDelegate"; + return UIApplicationMain(argc, argv, nil, delegateClassName); + } +} diff --git a/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj new file mode 100644 index 0000000000000000000000000000000000000000..9277c230b8cce1b5673a50d32d7640d52e2e8f9d --- /dev/null +++ b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj @@ -0,0 +1,359 @@ +// !$*UTF8*$! +{ + archiveVersion = 1; + classes = { + }; + objectVersion = 46; + objects = { + +/* Begin PBXBuildFile section */ + 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; }; + 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; }; + 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */; }; + 594C14B11FB9037100EE8BFE /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = 594C14AF1FB9037100EE8BFE /* labels.txt */; }; + 594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */; }; + 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; + 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; }; + 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; }; + 59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; }; + 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */; }; + 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */; }; +/* End PBXBuildFile section */ + +/* Begin PBXFileReference section */ + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; + 5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; + 594C14AF1FB9037100EE8BFE /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; + 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = ""; }; + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = ""; }; + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = ""; }; + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = ""; }; + 59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = ""; }; + 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "RunModel-Info.plist"; sourceTree = ""; }; + 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = ""; }; + 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = ""; }; + 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = ""; }; + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 591157981CF4011C00C31E3A /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */, + 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */, + 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 24D7686C331131624F4454A0 /* Frameworks */ = { + isa = PBXGroup; + children = ( + 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */, + 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, + 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, + 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */, + ); + name = Frameworks; + sourceTree = ""; + }; + 591157921CF4011C00C31E3A = { + isa = PBXGroup; + children = ( + 59A3CFF11CF4E68100C4259F /* AppDelegate.h */, + 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */, + 59A3CFF31CF4E68100C4259F /* data */, + 59A3CFFA1CF4E68100C4259F /* ios_image_load.h */, + 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */, + 59A3CFFC1CF4E68100C4259F /* main.mm */, + 59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */, + 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */, + 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */, + 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */, + 5911579C1CF4011C00C31E3A /* Products */, + 24D7686C331131624F4454A0 /* Frameworks */, + ); + sourceTree = ""; + }; + 5911579C1CF4011C00C31E3A /* Products */ = { + isa = PBXGroup; + children = ( + 5911579B1CF4011C00C31E3A /* tf_simple_example.app */, + ); + name = Products; + sourceTree = ""; + }; + 59A3CFF31CF4E68100C4259F /* data */ = { + isa = PBXGroup; + children = ( + 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */, + 594C14AF1FB9037100EE8BFE /* labels.txt */, + 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */, + ); + path = data; + sourceTree = ""; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 5911579A1CF4011C00C31E3A /* tf_simple_example */ = { + isa = PBXNativeTarget; + buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */; + buildPhases = ( + 591157971CF4011C00C31E3A /* Sources */, + 591157981CF4011C00C31E3A /* Frameworks */, + 591157991CF4011C00C31E3A /* Resources */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = tf_simple_example; + productName = tf_ios_makefile_example; + productReference = 5911579B1CF4011C00C31E3A /* tf_simple_example.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 591157931CF4011C00C31E3A /* Project object */ = { + isa = PBXProject; + attributes = { + LastUpgradeCheck = 0830; + ORGANIZATIONNAME = Google; + TargetAttributes = { + 5911579A1CF4011C00C31E3A = { + CreatedOnToolsVersion = 7.2; + DevelopmentTeam = EQHXZ8M8AV; + ProvisioningStyle = Manual; + }; + }; + }; + buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "simple" */; + compatibilityVersion = "Xcode 3.2"; + developmentRegion = English; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 591157921CF4011C00C31E3A; + productRefGroup = 5911579C1CF4011C00C31E3A /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 5911579A1CF4011C00C31E3A /* tf_simple_example */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 591157991CF4011C00C31E3A /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */, + 594C14B11FB9037100EE8BFE /* labels.txt in Resources */, + 59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */, + 594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 591157971CF4011C00C31E3A /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 59A3D0091CF4E68100C4259F /* main.mm in Sources */, + 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */, + 59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */, + 59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 591157B01CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + }; + name = Debug; + }; + 591157B11CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x"; + CLANG_CXX_LIBRARY = "libc++"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + "CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer"; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + GCC_C_LANGUAGE_STANDARD = gnu99; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 8.0; + MTL_ENABLE_DEBUG_INFO = NO; + SDKROOT = iphoneos; + TARGETED_DEVICE_FAMILY = "1,2"; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 591157B31CF4011D00C31E3A /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + CLANG_DEBUG_INFORMATION_LEVEL = default; + CODE_SIGN_IDENTITY = "iPhone Developer"; + DEVELOPMENT_TEAM = EQHXZ8M8AV; + ENABLE_BITCODE = NO; + GCC_ENABLE_CPP_EXCEPTIONS = YES; + GCC_ENABLE_CPP_RTTI = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 9.2; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE = "1072bd47-ff19-4e5f-8107-d912748f83f1"; + PROVISIONING_PROFILE_SPECIFIER = "Google Development"; + SEPARATE_STRIP = NO; + }; + name = Debug; + }; + 591157B41CF4011D00C31E3A /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + CLANG_DEBUG_INFORMATION_LEVEL = default; + CODE_SIGN_IDENTITY = "iPhone Developer"; + DEVELOPMENT_TEAM = ""; + ENABLE_BITCODE = NO; + GCC_ENABLE_CPP_EXCEPTIONS = YES; + GCC_ENABLE_CPP_RTTI = YES; + HEADER_SEARCH_PATHS = ( + "$(inherited)", + ../../../../../../, + ../../../downloads/flatbuffers/include/, + ../../../downloads/eigen/, + ../../../downloads/, + ); + INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; + IPHONEOS_DEPLOYMENT_TARGET = 9.2; + LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; + LIBRARY_SEARCH_PATHS = ../../../gen/lib/; + ONLY_ACTIVE_ARCH = YES; + OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; + OTHER_LDFLAGS = "$(inherited)"; + PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example"; + PRODUCT_NAME = "$(TARGET_NAME)"; + PROVISIONING_PROFILE_SPECIFIER = ""; + SEPARATE_STRIP = NO; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "simple" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B01CF4011D00C31E3A /* Debug */, + 591157B11CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 591157B31CF4011D00C31E3A /* Debug */, + 591157B41CF4011D00C31E3A /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 591157931CF4011C00C31E3A /* Project object */; +} diff --git a/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg b/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc83946647c6a923a8a0bd3a041b42e4febe6a31 Binary files /dev/null and b/tensorflow/contrib/lite/g3doc/TFLite-Architecture.jpg differ diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md new file mode 100644 index 0000000000000000000000000000000000000000..fe208e47d1ac10995881e55c8596ae14ff4242df --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/apis.md @@ -0,0 +1,359 @@ +# TensorFlow Lite APIs + +TensorFlow Lite provides programming APIs in C++ and Java, and in both cases +the API design reflects a preference for performance over ease of use. +TensorFlow Lite is designed for fast inference on small devices so it should be +no surprise that the APIs try to avoid unnecessary copies at the expense of +convenience. Similarly, consistency with TensorFlow APIs was not an explicit +goal and some variance is to be expected. + +## C++ + +In order to run the inference model in TensorFlow Lite, one has to load the +model into a `FlatBufferModel` object which then can be executed by an +`Interpreter`. The `FlatBufferModel` needs to remain valid for the whole +lifetime of the `Interpreter`, and a single `FlatBufferModel` can be +simultaneously used by more than one `Interpreter`. In concrete terms, the +`FlatBufferModel` object must be created before any `Interpreter` objects that +use it, and must be kept around until they have all been destroyed. + +The simplest usage of TensorFlow Lite will look like this: + +```c++ +tflite::FlatBufferModel model(path_to_model); +tflite::ops::builtin::BuiltinOpResolver resolver; +std::unique_ptr interpreter; +tflite::InterpreterBuilder(*model, resolver)(&interpreter); +// Resize input tensors, if desired. +interpreter->AllocateTensors(); +float* input = interpreter->typed_input_tensor(0); +// Fill `input`. +interpreter->Invoke(); +float* output = interpreter->type_output_tensor(0); +``` +### Data Alignment + +TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended +that all data provided to TensorFlow Lite be aligned that way. + +### Error Reporting + +In many places TensorFlow Lite returns status information through +`TfLiteStatus` objects: + +```c++ +typedef enum { + kTfLiteOk = 0, + kTfLiteError = 1 +} TfLiteStatus; + +``` + +Failures can be easily verified with: +```c++ +if (status != kTfLiteOk) { + // ... error handling here ... +} +``` + +In order to obtain detailed error information an ErrorReporter must be +provided: + +```c++ +class ErrorReporter { + virtual int Report(const char* format, va_list args) = 0; +}; +``` + +The `DefaultErrorReporter` takes care of reporting to `stderr`. + +### Loading a Model + +The `FlatBufferModel` class encapsulates a model and can be built in a couple of +slightly different ways depending on where the model is stored: + +```c++ +class FlatBufferModel { +  // Build a model based on a file. Return a nullptr in case of failure. +  static std::unique_ptr BuildFromFile( +      const char* filename, +      ErrorReporter* error_reporter); + +  // Build a model based on a pre-loaded flatbuffer. The caller retains +  // ownership of the buffer and should keep it alive until the returned object +  // is destroyed. Return a nullptr in case of failure. +  static std::unique_ptr BuildFromBuffer( +      const char* buffer, +      size_t buffer_size, +      ErrorReporter* error_reporter); +}; +``` + +Note that if TensorFlow Lite detects the presence of Android's NNAPI it will +automatically try to use shared memory to store the FlatBufferModel. + +### Running a Model + +Running a model involves a few simple steps: + + * Build an `Interpreter` based on an existing `FlatBufferModel` + * Optionally resize input tensors if the predefined sizes are not desired. + * Set input tensor values + * Invoke inference + * Read output tensor values + +The important parts of public interface of the `Interpreter` are provided +below. It should be noted that: + + * Tensors are represented by integers, in order to avoid string comparisons + (and any fixed dependency on string libraries). + * An interpreter must not be accessed from concurrent threads + * Memory allocation for input and output tensors must be triggered + by calling AllocateTensors() right after resizing tensors. + +```c++ +class Interpreter { + Interpreter(ErrorReporter* error_reporter); + + // Read only access to list of inputs. + const std::vector& inputs() const; + + // Read only access to list of outputs. + const std::vector& outputs() const; + + // Change the dimensionality of a given tensor. + TfLiteStatus ResizeInputTensor(int tensor_index, + const std::vector& dims); + + // Returns status of success or failure. + TfLiteStatus AllocateTensors(); + + // Return a pointer into the data of a given input tensor. + template + T* typed_input_tensor(int index) { + return typed_tensor(inputs_[index]); + } + + // Return a pointer into the data of a given output tensor. + template + T* typed_output_tensor(int index) { + return typed_tensor(outputs_[index]); + } + + // Execute the model, populating output tensors. + TfLiteStatus Invoke(); +}; +``` + +### Writing Custom Operators + +All TensorFlow Lite operators (both custom and builtin) are defined using a +simple pure-C interface that consists of four functions: + +```c++ +typedef struct { + void* (*init)(TfLiteContext* context, const char* buffer, size_t length); + void (*free)(TfLiteContext* context, void* buffer); + TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node); + TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node); +} TfLiteRegistration; +``` + +Refer to `context.h` for details on `TfLiteContext` and `TfLiteNode`. The +former provides error reporting facilities and access to global objects, +including all the tensors. The latter allows implementations to access their +inputs and outputs. + +When the interpreter loads a model, it calls init() once for each node in the +graph. A given `init()` will be called more than once if the op is used +multiple times in the graph. For custom ops a configuration buffer will be +provided, containing a flexbuffer that maps parameter names to their values. +The buffer is empty for builtin ops because the interpreter has already parsed +the op parameters. Kernel implementation that require state should initialize +it here and transfer ownership to the caller. For each `init()` call, there +will be a corresponding call to `free()`, allowing implementations to dispose +of the buffer they might have allocated in `init()`. + +Whenever the input tensors are resized the interpreter will go through the +graph notifying implementations of the change. This gives them the chance to +resize their internal buffer, check validity of input shapes and types, and +recalculate output shapes. This is all done through `prepare()` and +implementation can access their state using `node->user_data`. + +Finally, each time inference runs the interpreter traverses the graph calling +`invoke()`, and here too the state is available as `node->user_data`. + +Custom ops can be implemented in exactly the same way as builtin ops, by +defined those four functions and a global registration function that usually +looks like this: + +```c++ +namespace tflite { +namespace ops { +namespace custom { + TfLiteRegistration* Register_MY_CUSTOM_OP() { + static TfLiteRegistration r = {my_custom_op::Init, + my_custom_op::Free, + my_custom_op::Prepare, + my_custom_op::Eval}; + return &r; + } +} // namespace custom +} // namespace ops +} // namespace tflite +``` + +Note that registration is not automatic and an explicit call to +`Register_MY_CUSTOM_OP` should be made somewhere. While the standard +`:builtin_ops` takes care of the registration of builtins, custom ops will have +to be collected in separated custom libraries. + +### Customizing the kernel library + +Behind the scenes the interpreter will load a library of kernels which will be +assigned to execute each of the operators in the model. While the default +library only contains builtin kernels, it is possible to replace it with a +custom library. + +The interpreter uses an `OpResolver` to translate operator codes and names into +actual code: + +```c++ +class OpResolver { + virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; + virtual TfLiteRegistration* FindOp(const char* op) const = 0; + virtual void AddOp(tflite::BuiltinOperator op, TfLiteRegistration* registration) = 0; + virtual void AddOp(const char* op, TfLiteRegistration* registration) = 0; +}; +``` + +The regular usage will require the developer to use the `BuiltinOpResolver` and +write: + +```c++ +tflite::ops::builtin::BuiltinOpResolver resolver; +``` + +They can then optionally register custom ops: + +```c++ +resolver.AddOp("MY_CUSTOM_OP", Register_MY_CUSTOM_OP()); +``` + +before the resolver is passed to the `InterpreterBuilder`. + +If the set of builtin ops is deemed to be too large, a new `OpResolver` could +be code-generated based on a given subset of ops, possibly only the ones +contained in a given model. This is the equivalent of TensorFlow's selective +registration (and a simple version of it is available in the `tools` +directory). + +## Java + +TensorFlow Lite's Java API supports on-device inference and is provided as an +Android Studio Library that allows loading models, feeding inputs, and +retrieving inference outputs. + +The simplest usage of Tensorflow Lite Java API looks like this: + +```java +try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) { + interpreter.run(input, output); +} +``` + +### Loading a Model + +The `Interpreter.java` class drives model inference with TensorFlow Lite. In +most of the cases, this is the only class an app developer will need. + +#### Initializing an `Interpreter` With a Model File + +The `Interpreter` can be initialized with a model file using the constructor: + +```java +public Interpreter(@NotNull File modelFile); +``` + +or with a `MappedByteBuffer`: + +```java +public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer); +``` + +In both cases a valid TensorFlow Lite must be provided or an +`IllegalArgumentException` with be thrown. If a `MappedByteBuffer` is used to +initialize an Interpreter, it should remain unchanged for the whole lifetime of +the `Interpreter`. + +### Running a Model + +#### Supported Data Types + +To use TensorFlow Lite, the data types of the input and output tensors must be +one of the following primitive types: + +* `float` +* `int` +* `long` +* `byte` + +If other data types, including boxed types like `Integer` and `Float`, are used, +an `IllegalArgumentException` will be thrown. + +#### Inputs + +Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of +the supported primitive types. + +The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid +unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its +order must be `ByteOrder.nativeOrder()`. After it is used for a model inference, +it must remain unchanged until the model inference is finished. + +#### Outputs + +Each output should be an array, or a multi-dimensional array of the supported +primitive types. + +#### Running Model Inference + +If a model takes only one input and returns only one output, the following will +trigger an inference run: + +```java +interpreter.run(input, output); +``` + +For models with multiple inputs, or multiple outputs, use: + +```java +interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs); +``` + +where each entry in `inputs` corresponds to an input tensor and +`map_of_indices_to_outputs` maps indices of output tensors to the +corresponding output data. In both cases the tensor indices should correspond to +the values given to the `TensorFlow Lite Optimized Converter` when the model was +created. Be aware that the order of tensors in `input` must match the order +given to the `TensorFlow Lite Optimized Converter`. + +The Java API also provides convenient functions for app developers to get the +index of any model input or output using a tensor name: + +```java +public int getInputIndex(String tensorName); +public int getOutputIndex(String tensorName); +``` + +If tensorName is not a valid name in model, an `IllegalArgumentException` will +be thrown. + +### Releasing Resources After Use + +An `Interpreter` owns resources. To avoid memory leak, the resources must be +released after use by: + +```java +interpreter.close(); +``` diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md new file mode 100644 index 0000000000000000000000000000000000000000..204a489a93519309bb09238f1b2c8bbd4f1f19e4 --- /dev/null +++ b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -0,0 +1,91 @@ +# How to use custom operators + +TensorFlow Lite currently supports a subset of TensorFlow operators. However, it +does support the use of user-provided implementations (as known as custom +implementations) if the model contains an operator that is not supported. + +Let’s walk through this via an example. Assume we are using the `Sin` operator +and that we are building a very simple model for a function `y = sin(x + +offset)`, where `offset` is trainable. + +The code to train the TensorFlow model will be something like: + +```python +offset = tf.get_variable("offset", [1,], tf.float32) +x = tf.placeholder(tf.float32, shape=(None,)) +y = tf.sin(x + offset) +y_ = tf.placeholder(tf.float32, shape=(None,)) +loss = tf.reduce_sum(tf.square(y - y_)) +optimizer = tf.train.GradientDescentOptimizer(0.001) +train = optimizer.minimize(loss) +``` + +If you convert this model to Tensorflow Lite format using the TensorFlow Lite +Optimizing Converter with `--allow_custom_ops` argument, and run it with the +default interpreter, the interpreter will raise the following error messages: + +``` +Didn't find custom op for name 'Sin' +Registration failed. +``` + +All we need to do to use the op in TensorFlow Lite is define two functions +(`Prepare` and `Eval`), and construct a `TfLiteRegistration`. This code would +look something like this: + +```cpp +TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) { + using namespace tflite; + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + int num_dims = NumDimensions(input); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims); + for (int i=0; idata[i] = input->dims->data[i]; + } + + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { + using namespace tflite; + TfLiteTensor* input = GetInput(context, node,0); + TfLiteTensor* output = GetOutput(context, node,0); + + float* input_data = input->data.f; + float* output_data = output->data.f; + + size_t count = 1; + int num_dims = NumDimensions(input); + for (int i = 0; i < num_dims; ++i) { + count *= input->dims->data[i]; + } + + for (size_t i=0; i +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" + +namespace { + +// Memory allocation tuning +constexpr const int kDefaultArenaAlignment = 64; +constexpr const int kDefaultTensorAlignment = 4; +// std::vector preallocation tuning. +constexpr const int kSlotsToReserve = 128; + +} // namespace + +namespace tflite { + +Interpreter::Interpreter(ErrorReporter* error_reporter) + : arena_(kDefaultArenaAlignment), + persistent_arena_(kDefaultArenaAlignment), + error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + context_.impl_ = static_cast(this); + context_.ResizeTensor = ResizeTensor; + context_.ReportError = ReportError; + context_.AddTensors = AddTensors; + context_.tensors = nullptr; + context_.tensors_size = 0; + context_.gemm_context = nullptr; + // Reserve some space for the tensors to avoid excessive resizing. + tensors_.reserve(kSlotsToReserve); + nodes_and_registration_.reserve(kSlotsToReserve); + next_allocate_node_id_ = 0; + UseNNAPI(false); +} + +Interpreter::~Interpreter() { + for (auto& nodeAndReg : nodes_and_registration_) { + TfLiteNode& node = nodeAndReg.first; + TfLiteIntArrayFree(node.inputs); + TfLiteIntArrayFree(node.outputs); + TfLiteIntArrayFree(node.temporaries); + if (node.builtin_data) free(node.builtin_data); + OpFree(nodeAndReg.second, node.user_data); + node.builtin_data = nullptr; + } + + for (int i = 0; i < context_.tensors_size; i++) { + TfLiteTensorFree(&context_.tensors[i]); + } +} + +TfLiteStatus Interpreter::SetInputs(std::vector inputs) { + TF_LITE_ENSURE_OK(&context_, + CheckTensorIndices("inputs", inputs.data(), inputs.size())); + inputs_ = std::move(inputs); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::SetOutputs(std::vector outputs) { + TF_LITE_ENSURE_OK( + &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size())); + outputs_ = std::move(outputs); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::CheckTensorIndices(const char* label, + const int* indices, int length) { + // Making sure kOptionalTensor is not re-defined to something other than -1. + static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1"); + + for (int i = 0; i < length; i++) { + int index = indices[i]; + if (index < kOptionalTensor || index >= context_.tensors_size) { + ReportError(&context_, "Invalid tensor index %d in %s\n", index, label); + consistent_ = false; + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, + int dims_size, size_t* bytes) { + // TODO(aselle): Check for overflow here using overflow.h in TensorFlow + // MultiplyWithoutOverflow. + TF_LITE_ENSURE(&context_, bytes != nullptr); + size_t count = 1; + for (int k = 0; k < dims_size; k++) count *= dims[k]; + switch (type) { + case kTfLiteFloat32: + *bytes = sizeof(float) * count; + break; + case kTfLiteInt32: + *bytes = sizeof(int32_t) * count; + break; + case kTfLiteUInt8: + *bytes = sizeof(uint8_t) * count; + break; + case kTfLiteInt64: + *bytes = sizeof(int64_t) * count; + break; + default: + ReportError(&context_, + "Only float32, int32, int64, uint8 supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::AllocateTensorsWhoseSizesAreKnown() { + if (!consistent_) { + ReportError(&context_, "AllocateTensors() called on inconsistent model."); + return kTfLiteError; + } + if (next_allocate_node_id_ == nodes_and_registration_.size() && invokable_) { + return kTfLiteOk; + } + allocs_and_refcounts_.resize(context_.tensors_size); + + int new_next_allocate_node_id = next_allocate_node_id_; + invokable_ = false; + + // Allocate graph input nodes. + if (next_allocate_node_id_ == 0) { + for (int i = 0; i < inputs_.size(); ++i) { + int tensor_index = inputs_[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + // Add 1 to output tensors, so they will not get overwritten. + for (int i = 0; i < outputs_.size(); ++i) { + allocs_and_refcounts_[outputs_[i]].count++; + } + } + + // Count references to node input tensors, and resize node-referenced tensors + // until we encounter a node that has a dynamic output tensor. + for (int k = next_allocate_node_id_; k < nodes_and_registration_.size(); + k++) { + new_next_allocate_node_id++; + TfLiteNode& node = nodes_and_registration_[k].first; + const TfLiteRegistration& registration = nodes_and_registration_[k].second; + if (OpPrepare(registration, &node) == kTfLiteError) { + return kTfLiteError; + } + + TfLiteIntArray* node_inputs = node.inputs; + for (int i = 0; i < node_inputs->size; ++i) { + int tensor_index = node_inputs->data[i]; + if (tensor_index != kOptionalTensor) { + allocs_and_refcounts_[node_inputs->data[i]].count++; + } + } + + // Discontinue if the node has dynamic outputs. + bool has_unallocated_dynamic_tensor = false; + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + TfLiteTensor& tensor = context_.tensors[node_outputs->data[i]]; + if (tensor.allocation_type == kTfLiteDynamic) { + has_unallocated_dynamic_tensor = true; + break; + } + } + if (has_unallocated_dynamic_tensor) { + break; + } + } + + // Allocate graph persistent outputs, e.g. RNN cell states, etc. + for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { + TfLiteNode& node = nodes_and_registration_[k].first; + + // Go through output tensors and allocate the persistent ones first. + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + int tensor_index = node_outputs->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_OK(&context_, + persistent_arena_.Allocate( + &context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + } + + // Go through the graph in execution order. + for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { + TfLiteNode& node = nodes_and_registration_[k].first; + + // First allocate output tensors. + TfLiteIntArray* node_outputs = node.outputs; + for (int i = 0; i < node_outputs->size; ++i) { + int tensor_index = node_outputs->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + // Then the temporaries, in two passes. First allocate them all, them + // deallocate them. + TfLiteIntArray* node_temporaries = node.temporaries; + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, + &allocs_and_refcounts_[tensor_index].alloc)); + } + } + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TfLiteTensor& tensor = context_.tensors[tensor_index]; + allocs_and_refcounts_[tensor_index].count--; + if (tensor.allocation_type == kTfLiteArenaRw && + allocs_and_refcounts_[tensor_index].count == 0) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Deallocate(&context_, + allocs_and_refcounts_[tensor_index].alloc)); + } + } + + // Then process the node's inputs. + TfLiteIntArray* node_inputs = node.inputs; + for (int i = 0; i < node_inputs->size; ++i) { + int tensor_index = node_inputs->data[i]; + if (tensor_index == kOptionalTensor) { + continue; + } + TfLiteTensor& tensor = context_.tensors[tensor_index]; + + // Decrease reference count and deallocate if not needed anymore. + allocs_and_refcounts_[tensor_index].count--; + if (tensor.allocation_type == kTfLiteArenaRw && + allocs_and_refcounts_[tensor_index].count == 0) { + TF_LITE_ENSURE_OK( + &context_, + arena_.Deallocate(&context_, + allocs_and_refcounts_[tensor_index].alloc)); + } + } + } + + // Resize the buffer and commit the arena. + TF_LITE_ENSURE_OK(&context_, arena_.Commit(&context_)); + TF_LITE_ENSURE_OK(&context_, persistent_arena_.Commit(&context_)); + + // Rewire the tensors to use the underlying arena buffer. + for (int i = 0; i < context_.tensors_size; ++i) { + TfLiteTensor& tensor = context_.tensors[i]; + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_OK( + &context_, + arena_.ResolveAlloc(&context_, allocs_and_refcounts_[i].alloc, + &tensor.data.raw)); + } + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_OK( + &context_, + persistent_arena_.ResolveAlloc( + &context_, allocs_and_refcounts_[i].alloc, &tensor.data.raw)); + } + } + + invokable_ = true; + next_allocate_node_id_ = new_next_allocate_node_id; + return kTfLiteOk; +} + +namespace { +TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { + TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); + for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; + return lite; +} +} // namespace + +TfLiteStatus Interpreter::AllocateTensors() { + next_allocate_node_id_ = 0; + TF_LITE_ENSURE_OK(&context_, arena_.Clear()); + TF_LITE_ENSURE_OK(&context_, persistent_arena_.Clear()); + allocs_and_refcounts_.clear(); + return AllocateTensorsWhoseSizesAreKnown(); +} + +TfLiteStatus Interpreter::AddNodeWithParameters( + const std::vector& inputs, const std::vector& outputs, + const char* init_data, size_t init_data_size, void* builtin_data, + const TfLiteRegistration* registration, int* node_index) { + invokable_ = false; + + std::unique_ptr builtin_data_deleter(builtin_data, + free); + + TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(), + inputs.size())); + TF_LITE_ENSURE_OK( + &context_, + CheckTensorIndices("node outputs", outputs.data(), outputs.size())); + + if (node_index) *node_index = nodes_and_registration_.size(); + nodes_and_registration_.resize(nodes_and_registration_.size() + 1); + auto& node_and_reg = nodes_and_registration_.back(); + TfLiteNode& node = node_and_reg.first; + if (node.inputs) TfLiteIntArrayFree(node.inputs); + if (node.outputs) TfLiteIntArrayFree(node.outputs); + if (node.temporaries) TfLiteIntArrayFree(node.temporaries); + + // NOTE, here we are not using move semantics yet, since our internal + // representation isn't std::vector, but in the future we would like to avoid + // copies, so we want the interface to take r-value references now. + node.inputs = convertVectorToTfLiteIntArray(inputs); + node.outputs = convertVectorToTfLiteIntArray(outputs); + node.temporaries = TfLiteIntArrayCreate(0); + if (init_data) { + node.user_data = OpInit(*registration, init_data, init_data_size); + } else { + node.user_data = + OpInit(*registration, + reinterpret_cast(builtin_data_deleter.get()), 0); + } + node.builtin_data = builtin_data_deleter.release(); + node_and_reg.second = *registration; + return kTfLiteOk; +} + +TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, + const std::vector& dims) { + // TODO(aselle): All bounds checks can be implemented as one-sided bounds + // checks by casting to unsigned for efficiency. Profile before doing this. + + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + invokable_ = false; + TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims); + return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); +} + +TfLiteStatus Interpreter::Invoke() { + if (!consistent_) { + ReportError(&context_, "Invoke called on model that is not consistent."); + return kTfLiteError; + } + if (!invokable_) { + ReportError(&context_, "Invoke called on model that is not ready."); + return kTfLiteError; + } + + TfLiteStatus status = kTfLiteOk; + if (nnapi_delegate_) { + if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { + return kTfLiteError; + } + if (next_allocate_node_id_ == nodes_and_registration_.size()) { + TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); + return kTfLiteOk; + } else { + // TODO(aselle): In the future, we would like this to be an + // automatic tflite CPU fallback. + ReportError(&context_, + "NNAPI was requested, but dependent sized tensors " + "being used.\n"); + return kTfLiteError; + } + } + + for (int i = 0; i < nodes_and_registration_.size(); i++) { + // Ensure we have allocated up to this node. The point of this is to + // allocate as much as possible before running any evaluation, but + // dynamic shapes can prevent this from being possible. + if (i >= next_allocate_node_id_) { + if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { + return kTfLiteError; + } + } + TfLiteNode& node = nodes_and_registration_[i].first; + const TfLiteRegistration& registration = nodes_and_registration_[i].second; + if (OpInvoke(registration, &node) == kTfLiteError) { + status = kTfLiteError; + } + } + return status; +} + +TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context, + TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function ResizeTensorImpl + // (this function is static). + return static_cast(context->impl_) + ->ResizeTensorImpl(tensor, new_size); +} + +void Interpreter::ReportErrorImpl(const char* format, va_list args) { + error_reporter_->Report(format, args); +} + +void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) { + va_list args; + va_start(args, format); + auto* f = static_cast(context->impl_); + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function ReportErrorImpl + // (this function is static). + f->ReportErrorImpl(format, args); + va_end(args); +} + +TfLiteStatus Interpreter::AddTensors(int tensors_to_add, + int* first_new_tensor_index) { + int base_index = tensors_.size(); + if (first_new_tensor_index) *first_new_tensor_index = base_index; + tensors_.resize(tensors_.size() + tensors_to_add); + for (int i = base_index; i < tensors_.size(); i++) { + memset(&tensors_[i], 0, sizeof(tensors_[i])); + } + context_.tensors = tensors_.data(); + context_.tensors_size = tensors_.size(); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add, + int* first_new_tensor_index) { + // Note here that context->impl_ is recovering the this pointer for an + // instance of Interpreter to call into the member function AddTensors + // (this function is static). + return static_cast(context->impl_) + ->AddTensors(tensors_to_add, first_new_tensor_index); +} + +TfLiteStatus Interpreter::SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization, + const char* buffer, size_t bytes, const Allocation* allocation) { + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + // For most tensors we know exactly how much memory is necessary so we can + // ensure the buffer is large enough. However, we need to skip string tensors + // because their sizes change with the contents of the individual strings. + if (type != kTfLiteString) { + size_t required_bytes; + TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), + &required_bytes)); + TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes); + } + invokable_ = false; + TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + quantization, const_cast(buffer), bytes, + kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]); + return kTfLiteOk; +} + +// Set description of inputs/outputs/data/fptrs for node `node_index`. +// This variant assumes an external buffer has been allocated of size +// bytes. The lifetime of buffer must be ensured to be greater or equal +// to Interpreter. +TfLiteStatus Interpreter::SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization) { + invokable_ = false; + TF_LITE_ENSURE(&context_, + tensor_index < context_.tensors_size && tensor_index >= 0); + size_t required_bytes = 0; + if (type != kTfLiteString) { + // These types will be allocated in our arena so we need to record how + // many bytes we will need based on the dimensions. String tensors are + // allocated dynamically and we can't know ahead of time how much space + // they will require. + TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(), + &required_bytes)); + } + TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims), + quantization, + /*buffer=*/nullptr, required_bytes, + type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw, + nullptr, &context_.tensors[tensor_index]); + return kTfLiteOk; +} + +TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, + TfLiteIntArray* new_size) { + // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. + if (tensor->allocation_type == kTfLiteArenaRw || + tensor->allocation_type == kTfLiteDynamic) { + if (tensor->type != kTfLiteString) { + size_t bytesRequired; + TfLiteStatus status = BytesRequired(tensor->type, new_size->data, + new_size->size, &bytesRequired); + if (status != kTfLiteOk) { + TfLiteIntArrayFree(new_size); + return kTfLiteError; + } + tensor->bytes = bytesRequired; + } + if (tensor->dims) TfLiteIntArrayFree(tensor->dims); + tensor->dims = new_size; + + if (tensor->allocation_type != kTfLiteDynamic) { + tensor->data.raw = nullptr; + } + } else { + // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore + // of fixed size. + TfLiteIntArrayFree(new_size); + ReportError(&context_, "Attempting to resize a fixed-size tensor."); + return kTfLiteError; + } + return kTfLiteOk; +} + +void Interpreter::UseNNAPI(bool enable) { + // TODO(aselle): This is a workaround for finding if NNAPI exists. + // We also need to make sure getLibraryHandle() is renamed to be NNAPI + // prefixed. + if (!NNAPIExists()) enable = false; + if (!enable) { + nnapi_delegate_.reset(); + } else if (!nnapi_delegate_) { + nnapi_delegate_.reset(new NNAPIDelegate); + } +} + +void Interpreter::SetNumThreads(int num_threads) { + // TODO(ahentz): this forces us to link against gemmlowp even when the ops + // don't use it. We should implement some dynamic mechanism for this sort of + // library-specific initialization. + tflite::gemm_support::SetMaxNumThreads(&context_, num_threads); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h new file mode 100644 index 0000000000000000000000000000000000000000..65c61e44bee48535f884a3afaddc691972f5e04b --- /dev/null +++ b/tensorflow/contrib/lite/interpreter.h @@ -0,0 +1,374 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Main abstraction controlling the tflite interpreter. +// See context.h for the API for defining operations (TfLiteRegistration). +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +namespace tflite { + +// Map statically from a c++ type to a TfLiteType (used below for safe casts). +template +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteNoType; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt32; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteInt64; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteFloat32; +} +template <> +constexpr TfLiteType typeToTfLiteType() { + return kTfLiteUInt8; +} + +struct ArenaAllocRefCount { + ArenaAllocRefCount() : alloc(), count(0) {} + + ArenaAlloc alloc; + int count; +}; + +// Forward declare since NNAPIDelegate uses Interpreter. +class NNAPIDelegate; + +// An interpreter for a graph of nodes that input and output from tensors. +// Each node of the graph processes a set of input tensors and produces a +// set of output Tensors. All inputs/output tensors are referenced by index. +// +// Usage: +// +// -- Create basic model +// Interpreter foo(2, 1); +// foo.SetTensorParametersReadWrite(0, ...); +// foo.SetTensorParametersReadOnly(1, ...); +// foo.SetNodeParameters(0, ...) +// +// -- Resize input array to 1 length. +// foo.ResizeInputTensor(0, 1); +// foo.AllocateTensors(); +// -- Install array data +// foo.typed_tensor(0)[0] = 3; +// foo.Invoke(); +// foo.typed_tensor(0)[0] = 4; +// foo.Invoke(); +// -- Resize input array and set data. +// foo.ResizeInputTensor(0, 2); +// foo.AllocateTensors(); +// foo.typed_tensor(0)[0] = 4; +// foo.typed_tensor(0)[1] = 8; +// foo.Invoke(); +// + +class Interpreter { + public: + // Instantiate an interpreter. All errors associated with reading and + // processing this model will be forwarded to the error_reporter object. + // + // Note, if error_reporter is nullptr, then a default StderrReporter is + // used. + explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter()); + + ~Interpreter(); + + Interpreter(const Interpreter&) = delete; + Interpreter& operator=(const Interpreter&) = delete; + + // Functions to build interpreter + + // Provide a list of tensor indexes that are inputs to the model. + // Each index is bound check and this modifies the consistent_ flag of the + // interpreter. + TfLiteStatus SetInputs(std::vector inputs); + + // Provide a list of tensor indexes that are outputs to the model + // Each index is bound check and this modifies the consistent_ flag of the + // interpreter. + TfLiteStatus SetOutputs(std::vector outputs); + + // Adds a node with the given parameters and returns the index of the new + // node in `node_index` (optionally). Interpreter will take ownership of + // `builtin_data` and destroy it with `delete`. Ownership of 'init_data' + // remains with the caller. + TfLiteStatus AddNodeWithParameters(const std::vector& inputs, + const std::vector& outputs, + const char* init_data, + size_t init_data_size, void* builtin_data, + const TfLiteRegistration* registration, + int* node_index = nullptr); + + // Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries. + // The value pointed to by `first_new_tensor_index` will be set to the + // index of the first new tensor if `first_new_tensor_index` is non-null. + TfLiteStatus AddTensors(int tensors_to_add, + int* first_new_tensor_index = nullptr); + + // Set description of inputs/outputs/data/fptrs for node `node_index`. + // This variant assumes an external buffer has been allocated of size + // bytes. The lifetime of buffer must be ensured to be greater or equal + // to Interpreter. + TfLiteStatus SetTensorParametersReadOnly( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization, + const char* buffer, size_t bytes, const Allocation* allocation = nullptr); + + // Set description of inputs/outputs/data/fptrs for node `node_index`. + // This variant assumes an external buffer has been allocated of size + // bytes. The lifetime of buffer must be ensured to be greater or equal + // to Interpreter. + TfLiteStatus SetTensorParametersReadWrite( + int tensor_index, TfLiteType type, const char* name, + const std::vector& dims, TfLiteQuantizationParams quantization); + + // Functions to access tensor data + + // Read only access to list of inputs. + const std::vector& inputs() const { return inputs_; } + + // Return the name of a given input. The given index must be between 0 and + // inputs().size(). + const char* GetInputName(int index) const { + return context_.tensors[inputs_[index]].name; + } + + // Read only access to list of outputs. + const std::vector& outputs() const { return outputs_; } + + // Return the name of a given output. The given index must be between 0 and + // outputs().size(). + const char* GetOutputName(int index) const { + return context_.tensors[outputs_[index]].name; + } + + // Return the number of tensors in the model. + int tensors_size() const { return context_.tensors_size; } + + // Return the number of ops in the model. + int nodes_size() const { return nodes_and_registration_.size(); } + + // Get a tensor data structure. + // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this + // read/write access to structure + TfLiteTensor* tensor(int tensor_index) { + if (tensor_index >= context_.tensors_size || tensor_index < 0) + return nullptr; + return &context_.tensors[tensor_index]; + } + + // Get a pointer to an operation and registration data structure if in bounds. + // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this + // read/write access to structure + const std::pair* node_and_registration( + int node_index) { + if (node_index >= nodes_and_registration_.size() || node_index < 0) + return nullptr; + return &nodes_and_registration_[node_index]; + } + + // Perform a checked cast to the appropriate tensor type. + template + T* typed_tensor(int tensor_index) { + if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) { + if (tensor_ptr->type == typeToTfLiteType()) { + return reinterpret_cast(tensor_ptr->data.raw); + } + } + return nullptr; + } + + // Return a pointer into the data of a given input tensor. The given index + // must be between 0 and inputs().size(). + template + T* typed_input_tensor(int index) { + return typed_tensor(inputs_[index]); + } + + // Return a pointer into the data of a given output tensor. The given index + // must be between 0 and outputs().size(). + template + T* typed_output_tensor(int index) { + return typed_tensor(outputs_[index]); + } + + // Change the dimensionality of a given tensor. Note, this is only acceptable + // for tensor indices that are inputs. + // Returns status of failure or success. + // TODO(aselle): Consider implementing ArraySlice equivalent to make this + // more adept at accepting data without an extra copy. Use absl::ArraySlice + // if our partners determine that dependency is acceptable. + TfLiteStatus ResizeInputTensor(int tensor_index, + const std::vector& dims); + + // Update allocations for all tensors. This will redim dependent tensors using + // the input tensor dimensionality as given. This is relatively expensive. + // If you know that your sizes are not changing, you need not call this. + + // Returns status of success or failure. + TfLiteStatus AllocateTensors(); + + // Invoke the interpreter (run the whole graph in dependency order). + // + // NOTE: It is possible that the interpreter is not in a ready state + // to evaluate (i.e. if a ResizeTensor() has been performed without an + // AllocateTensors(). + // Returns status of success or failure. + TfLiteStatus Invoke(); + + // Enable or disable the NN API (true to enable) + void UseNNAPI(bool enable); + + // Set the number of threads available to the interpreter. + void SetNumThreads(int num_threads); + + private: + // Give 'op_reg' a chance to initialize itself using the contents of + // 'buffer'. + void* OpInit(const TfLiteRegistration& op_reg, const char* buffer, + size_t length) { + if (op_reg.init == nullptr) return nullptr; + return op_reg.init(&context_, buffer, length); + } + + // Let 'op_reg' release any memory it might have allocated via 'OpInit'. + void OpFree(const TfLiteRegistration& op_reg, void* buffer) { + if (op_reg.free == nullptr) return; + if (buffer) { + op_reg.free(&context_, buffer); + } + } + + // Prepare the given 'node' for execution. + TfLiteStatus OpPrepare(const TfLiteRegistration& op_reg, TfLiteNode* node) { + if (op_reg.prepare == nullptr) return kTfLiteOk; + return op_reg.prepare(&context_, node); + } + + // Invoke the operator represented by 'node'. + TfLiteStatus OpInvoke(const TfLiteRegistration& op_reg, TfLiteNode* node) { + if (op_reg.invoke == nullptr) return kTfLiteError; + return op_reg.invoke(&context_, node); + } + + // Allocate tensors whose sizes are known in order of nodes. Discontinue when + // we encounter a node that has a dynamic output tensor. + TfLiteStatus AllocateTensorsWhoseSizesAreKnown(); + + // Tensors needed by the interpreter. Use `AddTensors` to add more blank + // tensor entries. Note, `tensors_.data()` needs to be synchronized to the + // `context_` whenever this std::vector is reallocated. Currently this + // only happens in `AddTensors()`. + std::vector tensors_; + + // Check if an array of tensor indices are valid with respect to the Tensor + // array. + // NOTE: this changes consistent_ to be false if indices are out of bounds. + TfLiteStatus CheckTensorIndices(const char* label, const int* indices, + int length); + + // Compute the number of bytes required to represent a tensor with dimensions + // specified by the array dims (of length dims_size). Returns the status code + // and bytes. + TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size, + size_t* bytes); + + // Request an tensor be resized implementation. + TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size); + + // Report a detailed error string (will be printed to stderr). + // TODO(aselle): allow user of class to provide alternative destinations. + void ReportErrorImpl(const char* format, va_list args); + + // Entry point for C node plugin API to request an tensor be resized. + static TfLiteStatus ResizeTensor(TfLiteContext* context, TfLiteTensor* tensor, + TfLiteIntArray* new_size); + // Entry point for C node plugin API to report an error. + static void ReportError(TfLiteContext* context, const char* format, ...); + + // Entry point for C node plugin API to add new tensors. + static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add, + int* first_new_tensor_index); + + // A pure C data structure used to communicate with the pure C plugin + // interface. To avoid copying tensor metadata, this is also the definitive + // structure to store tensors. + TfLiteContext context_; + + // Node inputs/outputs are stored in TfLiteNode and TfLiteRegistration stores + // function pointers to actual implementation. + std::vector> + nodes_and_registration_; + + // Raw memory buffer that is allocated for all temporary and graph outputs. + // that are declared kTfLiteArenaRw. + SimpleMemoryArena arena_; + + // Raw memory buffer that is allocated for persistent tensors that are + // declared as kTfLiteArenaRwPersistent. + SimpleMemoryArena persistent_arena_; + + // Stores allocation and reference counts of all tensors. + std::vector allocs_and_refcounts_; + + // Whether the model is consistent. That is to say if the inputs and outputs + // of every node and the global inputs and outputs are valid indexes into + // the tensor array. + bool consistent_ = true; + + // Whether the model is safe to invoke (if any errors occurred this + // will be false). + bool invokable_ = false; + + // Array of indices representing the tensors that are inputs to the + // interpreter. + std::vector inputs_; + + // Array of indices representing the tensors that are outputs to the + // interpreter. + std::vector outputs_; + + // The error reporter delegate that tflite will forward queries errors to. + ErrorReporter* error_reporter_; + + // Next node to allocate output tensors. + // During Invoke(), Interpreter will allocate input tensors first, which are + // known to be fixed size. Then it will allocate outputs from nodes as many + // as possible. When there is a node that produces dynamic sized tensor. + // Intepreter will stop allocating tensors, set the value of next allocate + // node id, and execute the node to generate the output tensor before continue + // to allocate successors. This process repeats until all nodes are executed. + // NOTE: this relies on the order of nodes that is in topological order. + int next_allocate_node_id_; + + // Whether to delegate to NN API + std::unique_ptr nnapi_delegate_; +}; + +} // namespace tflite +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..edff2109430c6e1ec6c481619ed7772237a3301d --- /dev/null +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -0,0 +1,526 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/interpreter.h" +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +// Make an interpreter that has no tensors and no nodes +TEST(BasicInterpreter, ZeroInterpreter) { + Interpreter interpreter; + interpreter.SetInputs({}); + interpreter.SetOutputs({}); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Test various error conditions. +TEST(BasicInterpreter, InvokeInvalidModel) { + Interpreter interpreter; + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Test size accesser functions. +TEST(BasicInterpreter, TestSizeFunctions) { + Interpreter interpreter; + int base_index; + ASSERT_EQ(interpreter.nodes_size(), 0); + ASSERT_EQ(interpreter.tensors_size(), 0); + ASSERT_EQ(interpreter.AddTensors(2, &base_index), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 2); + ASSERT_EQ(base_index, 0); + ASSERT_EQ(interpreter.AddTensors(3, &base_index), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 5); + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.tensors_size(), 6); + ASSERT_EQ(base_index, 2); +} + +// Test if invalid indices make a model inconsistent (and conversely if +// valid indices keep a model consistent). +TEST(BasicInterpreter, InconsistentModel) { + // Invalid inputs + { + Interpreter interpreter; + ASSERT_NE(interpreter.SetInputs({5}), kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.inputs(), std::vector()); + } + // Invalid outputs + { + Interpreter interpreter; + ASSERT_NE(interpreter.SetOutputs({5}), kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(interpreter.outputs(), std::vector()); + } + // Invalid node inputs + { + Interpreter interpreter; + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_NE(interpreter.AddNodeWithParameters({3}, {0}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + ASSERT_NE(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + } + // Valid inputs and outputs and a node with valid inputs and outputs + { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + } +} + +// Make an interpreter that has one tensor but no ops +TEST(BasicInterpreter, CheckAllocate) { + struct { + TfLiteType type; + size_t size; + } cases[] = { + {kTfLiteFloat32, sizeof(float)}, + {kTfLiteInt32, sizeof(int32_t)}, + {kTfLiteUInt8, sizeof(uint8_t)}, + {kTfLiteInt64, sizeof(int64_t)}, + }; + + for (auto test : cases) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteQuantizationParams quant; + + interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant); + interpreter.SetTensorParametersReadWrite(1, test.type, "", {4}, quant); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.tensor(0)->bytes, 3 * test.size); + ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(1)->bytes, 4 * test.size); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + } +} + +TEST(BasicInterpreter, CheckResize) { + const float floats[] = {-3., -4.}; + const int32_t int32s[] = {-3, -4}; + const uint8_t uint8s[] = {3, 4}; + const int64_t int64s[] = {6, -7}; + + struct { + TfLiteType type; + size_t size; + const char* array; + } cases[] = { + {kTfLiteFloat32, sizeof(float), reinterpret_cast(floats)}, + {kTfLiteInt32, sizeof(int32_t), reinterpret_cast(int32s)}, + {kTfLiteUInt8, sizeof(uint8_t), reinterpret_cast(uint8s)}, + {kTfLiteInt64, sizeof(int64_t), reinterpret_cast(int64s)}, + }; + + for (auto test : cases) { + Interpreter interpreter; + + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({}); + TfLiteQuantizationParams quant; + + ASSERT_EQ( + interpreter.SetTensorParametersReadWrite(0, test.type, "", {3}, quant), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadOnly( + 1, test.type, "", {2}, quant, test.array, 2 * test.size), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {1, 2}), kTfLiteOk); + // Resizing a mmapped tensor is not allowed and should produce error. + ASSERT_NE(interpreter.ResizeInputTensor(1, {3}), kTfLiteOk); + // Set the tensor to be mmapped but with a buffer size that is insufficient + // to match the dimensionality. + ASSERT_NE(interpreter.SetTensorParametersReadOnly( + 1, test.type, "", {2}, quant, test.array, 1 * test.size), + kTfLiteOk); + // Allocating should work since we should have our last correct array + // values in place. + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + } +} + +TEST(BasicInterpreter, CheckAlignment) { + struct { + TfLiteType type; + } cases[] = { + {kTfLiteFloat32}, + {kTfLiteInt32}, + {kTfLiteUInt8}, + {kTfLiteInt64}, + }; + + for (auto test : cases) { + Interpreter interpreter; + + ASSERT_EQ(interpreter.AddTensors(4), kTfLiteOk); + + for (int i = 0; i < 4; i++) { + TfLiteQuantizationParams quant; + interpreter.SetTensorParametersReadWrite(i, test.type, "", {2 * i + 1}, + quant); + } + interpreter.AllocateTensors(); + for (int i = 0; i < 4; i++) { + const TfLiteTensor& tensor = *interpreter.tensor(i); + ASSERT_EQ(reinterpret_cast(tensor.data.raw) % 4, 0); + } + } +} + +TEST(BasicInterpreter, CheckArenaAllocation) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(10), kTfLiteOk); + + TfLiteQuantizationParams quant; + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + std::vector sizes{2048, 4096, 1023, 2047, 1021, + 2047, 1023, 2046, 1021, 2048}; + for (int i = 0; i < sizes.size(); ++i) { + interpreter.SetTensorParametersReadWrite(i, kTfLiteUInt8, "", {sizes[i]}, + quant); + } + interpreter.SetInputs({0, 1}); + interpreter.SetOutputs({9, 4}); + interpreter.AddNodeWithParameters({0, 1}, {2, 3}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({2, 1}, {4, 5}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({4, 3}, {6, 7}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({6, 5}, {8}, nullptr, 0, nullptr, ®); + interpreter.AddNodeWithParameters({8, 7}, {9}, nullptr, 0, nullptr, ®); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.tensor(0)->data.raw, interpreter.tensor(4)->data.raw); + ASSERT_EQ(interpreter.tensor(1)->data.raw, interpreter.tensor(7)->data.raw); + + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(1)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(1)->data.raw); + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(1)->data.raw); + + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(3)->data.raw); + ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(3)->data.raw); + + ASSERT_LT(interpreter.tensor(0)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(1)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(2)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(3)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(4)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(6)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(7)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(8)->data.raw, interpreter.tensor(5)->data.raw); + ASSERT_LT(interpreter.tensor(9)->data.raw, interpreter.tensor(5)->data.raw); +} + +TEST(BasicInterpreter, BufferAccess) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + // Verify we get a valid pointer.r + ASSERT_NE(interpreter.typed_tensor(0), nullptr); + // Verify incorrect pointer will not returned. + ASSERT_EQ(interpreter.typed_tensor(0), nullptr); + // Verify that raw c interface ptr matches safe interface. + ASSERT_EQ(interpreter.typed_tensor(0), interpreter.tensor(0)->data.f); +} + +TEST(BasicInterpreter, NoOpInterpreter) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + ASSERT_EQ(interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 2, 3}), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +TEST(BasicInterpreter, OneOpInterpreter) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "in1", + {3}, quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "out0", + {3}, quantized), + kTfLiteOk); + + ASSERT_EQ(interpreter.GetInputName(0), "in1"); + ASSERT_EQ(interpreter.GetOutputName(0), "out0"); + + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.init = [](TfLiteContext* context, const char*, size_t) -> void* { + auto* first_new_tensor = new int; + context->AddTensors(context, 2, first_new_tensor); + return first_new_tensor; + }; + reg.free = [](TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); + }; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + auto* first_new_tensor = reinterpret_cast(node->user_data); + + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor1, newSize)); + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + for (int i = 0; i < 2; ++i) { + node->temporaries->data[i] = *(first_new_tensor) + i; + } + + auto setup_temporary = [&](int id) { + TfLiteTensor* tmp = &context->tensors[id]; + tmp->type = kTfLiteFloat32; + tmp->allocation_type = kTfLiteArenaRw; + return context->ResizeTensor(context, tmp, + TfLiteIntArrayCopy(tensor0->dims)); + }; + TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[0])); + TF_LITE_ENSURE_STATUS(setup_temporary(node->temporaries->data[1])); + + return kTfLiteOk; + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + + auto populate = [&](int id) { + TfLiteTensor* t = &context->tensors[id]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + t->data.f[i] = a0->data.f[i]; + } + }; + + populate(node->outputs->data[0]); + populate(node->temporaries->data[0]); + populate(node->temporaries->data[1]); + return kTfLiteOk; + }; + ASSERT_EQ( + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®), + kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); +} + +// Forcefully divides tensor allocation in three steps: one before invocation +// and two more at invocation time. This happens because we use string tensors +// and their sizes can't be determined until invocation time. +TEST(BasicInterpreter, ThreeStepAllocate) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(5), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({4}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + char data[] = {1, 0, 0, 0, 12, 0, 0, 0, 15, 0, 0, 0, 'A', 'B', 'C'}; + // Read only string tensor. + ASSERT_EQ(interpreter.SetTensorParametersReadOnly(0, kTfLiteString, "", {1}, + quantized, data, 15), + kTfLiteOk); + // Read-write string tensor. + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteString, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(2, kTfLiteInt32, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(3, kTfLiteString, "", {1}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(4, kTfLiteInt32, "", {1}, + quantized), + kTfLiteOk); + + // String-in String-out node. + TfLiteRegistration reg_copy = {nullptr, nullptr, nullptr, nullptr}; + reg_copy.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + DynamicBuffer buf; + StringRef str_ref = GetString(a0, 0); + buf.AddString(str_ref); + buf.WriteToTensor(a1); + return kTfLiteOk; + }; + + // String-in Int-out node. + TfLiteRegistration reg_len = {nullptr, nullptr, nullptr, nullptr}; + reg_len.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + outputSize->data[0] = 1; + return context->ResizeTensor(context, output, outputSize); + }; + reg_len.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + a1->data.i32[0] = a0->bytes; + return kTfLiteOk; + }; + + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®_copy), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({1}, {2}, nullptr, 0, nullptr, + ®_len), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {3}, nullptr, 0, nullptr, + ®_copy), + kTfLiteOk); + ASSERT_EQ(interpreter.AddNodeWithParameters({3}, {4}, nullptr, 0, nullptr, + ®_len), + kTfLiteOk); + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + + ASSERT_EQ(interpreter.tensor(0)->bytes, 15); + ASSERT_NE(interpreter.tensor(0)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(1)->bytes, 15); + ASSERT_NE(interpreter.tensor(1)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(3)->bytes, 15); + ASSERT_NE(interpreter.tensor(4)->data.raw, nullptr); + ASSERT_EQ(interpreter.tensor(2)->bytes, 4); + ASSERT_EQ(interpreter.tensor(2)->data.i32[0], 15); + ASSERT_EQ(interpreter.tensor(4)->bytes, 4); + ASSERT_EQ(interpreter.tensor(4)->data.i32[0], 15); +} + +TEST(BasicInterpreter, AllocateTwice) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({1}), kTfLiteOk); + + TfLiteQuantizationParams quantized; + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quantized), + kTfLiteOk); + ASSERT_EQ(interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quantized), + kTfLiteOk); + + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + return context->ResizeTensor(context, tensor1, newSize); + }; + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + a1->data.f[i] = a0->data.f[i]; + } + return kTfLiteOk; + }; + ASSERT_EQ( + interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, ®), + kTfLiteOk); + ASSERT_EQ(interpreter.ResizeInputTensor(0, {3}), kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + char* old_tensor0_ptr = interpreter.tensor(0)->data.raw; + char* old_tensor1_ptr = interpreter.tensor(1)->data.raw; + + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(old_tensor0_ptr, interpreter.tensor(0)->data.raw); + ASSERT_EQ(old_tensor1_ptr, interpreter.tensor(1)->data.raw); +} + +struct TestErrorReporter : public ErrorReporter { + int Report(const char* format, va_list args) override { + char buffer[1024]; + int size = vsnprintf(buffer, sizeof(buffer), format, args); + all_reports += buffer; + calls++; + return size; + } + int calls = 0; + std::string all_reports; +}; + +TEST(BasicInterpreter, TestNullErrorReporter) { + TestErrorReporter reporter; + Interpreter interpreter; +} + +TEST(BasicInterpreter, TestCustomErrorReporter) { + TestErrorReporter reporter; + Interpreter interpreter(&reporter); + ASSERT_NE(interpreter.Invoke(), kTfLiteOk); + ASSERT_EQ(reporter.all_reports, "Invoke called on model that is not ready."); + ASSERT_EQ(reporter.calls, 1); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { +#ifdef OS_LINUX + FLAGS_logtostderr = true; +#endif + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc new file mode 100644 index 0000000000000000000000000000000000000000..bcff7ed9889e95c13294b6cf0d0f4788991a04df --- /dev/null +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -0,0 +1,47 @@ +# Settings for iOS. +ifeq ($(TARGET), IOS) + BUILD_FOR_IOS_SIMULATOR := false + ifeq ($(IOS_ARCH), x86_64) + BUILD_FOR_IOS_SIMULATOR := true + endif + ifeq ($(IOS_ARCH), i386) + BUILD_FOR_IOS_SIMULATOR := true + endif + ifeq ($(BUILD_FOR_IOS_SIMULATOR), true) + IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \ + --show-sdk-platform-path) + IPHONEOS_SYSROOT := $(shell xcrun --sdk iphonesimulator \ + --show-sdk-path) + else + IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path) + IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path) + endif + IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version) + MIN_SDK_VERSION := 9.0 + # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64. + IOS_ARCH := x86_64 + CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ + -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -fembed-bitcode \ + -Wno-c++11-narrowing \ + -mno-thumb \ + -fno-exceptions \ + -isysroot \ + ${IPHONEOS_SYSROOT} \ + -arch $(IOS_ARCH) \ + -O3 + CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ + -fembed-bitcode \ + -mno-thumb \ + -isysroot \ + ${IPHONEOS_SYSROOT} \ + -arch $(IOS_ARCH) \ + -O3 + LDFLAGS := -fembed-bitcode \ + -miphoneos-version-min=${MIN_SDK_VERSION} \ + -arch $(IOS_ARCH) + OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ + LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ + BINDIR := $(BINDIR)ios_$(IOS_ARCH)/ + DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/ +endif diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..1de28eb52ddb458df0be0a8f9ef453f7caf68654 --- /dev/null +++ b/tensorflow/contrib/lite/java/BUILD @@ -0,0 +1,150 @@ +# Description: +# TensorFlow Lite Java API. + +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") + +android_library( + name = "tensorflowlite", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + visibility = ["//visibility:public"], + deps = [ + ":tflite_runtime", + "@javax_validation", + ], +) + +android_library( + name = "tensorflowlite_java", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + visibility = ["//visibility:public"], + deps = [ + "@javax_validation", + ], +) + +java_library( + name = "tensorflowlitelib", + srcs = glob( + [ + "src/main/java/org/tensorflow/lite/*.java", + ], + ), + javacopts = JAVACOPTS, + visibility = ["//visibility:public"], + deps = [ + ":libtensorflowlite_jni.so", + "//tensorflow/contrib/lite/java/src/main/native", + "@javax_validation", + ], +) + +java_test( + name = "TensorFlowLiteTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.TensorFlowLiteTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "DataTypeTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/DataTypeTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.DataTypeTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "NativeInterpreterWrapperTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java"], + data = [ + "src/testdata/add.bin", + "src/testdata/int32.bin", + "src/testdata/int64.bin", + "src/testdata/invalid_model.bin", + "src/testdata/uint8.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest", + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +java_test( + name = "TensorTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"], + data = [ + "src/testdata/add.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.TensorTest", + deps = [ + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + +filegroup( + name = "libtensorflowlite_jni", + srcs = select({ + "//conditions:default": [":libtensorflowlite_jni.so"], + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "tflite_runtime", + srcs = ["libtensorflowlite_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libtensorflowlite_jni.so", + deps = [ + "//tensorflow/contrib/lite/java/src/main/native", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/.gitignore b/tensorflow/contrib/lite/java/demo/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..39fb081a42a86ccf8f9cf99dbccc8bdf7c828bce --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/.gitignore @@ -0,0 +1,9 @@ +*.iml +.gradle +/local.properties +/.idea/workspace.xml +/.idea/libraries +.DS_Store +/build +/captures +.externalNativeBuild diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -0,0 +1,46 @@ +# TF Lite Android App + +## Building from Source with Bazel + +1. Follow the [Bazel steps for the TF Demo App](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#bazel): + + 1. [Install Bazel and Android Prerequisites](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites). + It's easiest with Android Studio. + + - You'll need at least SDK version 23. + - Make sure to install the latest version of Bazel. Some distributions + ship with Bazel 0.5.4, which is too old. + - Bazel requires Android Build Tools `26.0.1` or higher. + - **Bazel is incompatible with NDK revisions 15 and above,** with revision + 16 being a compile-breaking change. [Download an older version manually + instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites) + - You also need to install the Android Support Repository, available + through Android Studio under `Android SDK Manager -> SDK Tools -> + Android Support Repository`. + + 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace) + to add SDK and NDK targets. + + NOTE: As long as you have the SDK and NDK installed, the `./configure` + script will create these rules for you. Answer "Yes" when the script asks + to automatically configure the `./WORKSPACE`. + + - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that + you have installed. + - By default, Android Studio will install the SDK to `~/Android/Sdk` and + the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual + download until Bazel supports NDK 16. See bullet points under (1)). + +2. Build the app with Bazel. The demo needs C++11: + + ```shell + bazel build -c opt --cxxopt='--std=c++11' \ + //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo + ``` + +3. Install the demo on a + [debug-enabled device](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install): + + ```shell + adb install bazel-bin/tensorflow/contrib/lite/java/demo/app/src/main/TfLiteCameraDemo.apk + ``` diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b76eaad8bb91224805d16b3d6f7c3274c9feb90c --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/build.gradle @@ -0,0 +1,58 @@ +apply plugin: 'com.android.application' + +android { + compileSdkVersion 26 + buildToolsVersion "26.0.1" + defaultConfig { + applicationId "android.example.com.tflitecamerademo" + minSdkVersion 15 + targetSdkVersion 26 + versionCode 1 + versionName "1.0" + testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner" + + // Remove this block. + jackOptions { + enabled true + } + } + lintOptions { + abortOnError false + } + buildTypes { + release { + minifyEnabled false + proguardFiles getDefaultProguardFile('proguard-android.txt'), 'proguard-rules.pro' + } + } + aaptOptions { + noCompress "tflite" + } + + compileOptions { + sourceCompatibility JavaVersion.VERSION_1_8 + targetCompatibility JavaVersion.VERSION_1_8 + } +} + +repositories { + maven { + url 'https://google.bintray.com/tensorflow' + } +} + +dependencies { + compile fileTree(dir: 'libs', include: ['*.jar']) + androidTestCompile('com.android.support.test.espresso:espresso-core:2.2.2', { + exclude group: 'com.android.support', module: 'support-annotations' + }) + compile 'com.android.support:appcompat-v7:25.2.0' + compile 'com.android.support.constraint:constraint-layout:1.0.2' + compile 'com.android.support:design:25.2.0' + compile 'com.android.support:support-annotations:25.3.1' + compile 'com.android.support:support-v13:25.2.0' + + compile 'org.tensorflow:tensorflow-lite:+' + + testCompile 'junit:junit:4.12' +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..ba63dce5d9a7192a2c3c4c5561333d39a3ecc024 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,42 @@ + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..654fa9d6d2799fc3cafa3e0e042cb2a5746bf2c5 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -0,0 +1,41 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +android_binary( + name = "TfLiteCameraDemo", + srcs = glob(["java/**/*.java"]), + assets = [ + "@tflite_mobilenet//:labels.txt", + "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", + ], + assets_dir = "", + custom_package = "com.example.android.tflitecamerademo", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + # In some platforms we don't have an Android SDK/NDK and this target + # can't be built. We need to prevent the build system from trying to + # use the target in that case. + tags = ["manual"], + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite", + "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..dd0cd6c98ff878e9c41875cab74c12191cadb173 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files( + glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt new file mode 100644 index 0000000000000000000000000000000000000000..fe811239d8e2989de19fecabb1ebb0c9dddac514 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt @@ -0,0 +1,1001 @@ +background +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java new file mode 100644 index 0000000000000000000000000000000000000000..f2045906599218871b51a752dcbb3eeb23b8f085 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/AutoFitTextureView.java @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.content.Context; +import android.util.AttributeSet; +import android.view.TextureView; + +/** A {@link TextureView} that can be adjusted to a specified aspect ratio. */ +public class AutoFitTextureView extends TextureView { + + private int mRatioWidth = 0; + private int mRatioHeight = 0; + + public AutoFitTextureView(Context context) { + this(context, null); + } + + public AutoFitTextureView(Context context, AttributeSet attrs) { + this(context, attrs, 0); + } + + public AutoFitTextureView(Context context, AttributeSet attrs, int defStyle) { + super(context, attrs, defStyle); + } + + /** + * Sets the aspect ratio for this view. The size of the view will be measured based on the ratio + * calculated from the parameters. Note that the actual sizes of parameters don't matter, that is, + * calling setAspectRatio(2, 3) and setAspectRatio(4, 6) make the same result. + * + * @param width Relative horizontal size + * @param height Relative vertical size + */ + public void setAspectRatio(int width, int height) { + if (width < 0 || height < 0) { + throw new IllegalArgumentException("Size cannot be negative."); + } + mRatioWidth = width; + mRatioHeight = height; + requestLayout(); + } + + @Override + protected void onMeasure(int widthMeasureSpec, int heightMeasureSpec) { + super.onMeasure(widthMeasureSpec, heightMeasureSpec); + int width = MeasureSpec.getSize(widthMeasureSpec); + int height = MeasureSpec.getSize(heightMeasureSpec); + if (0 == mRatioWidth || 0 == mRatioHeight) { + setMeasuredDimension(width, height); + } else { + if (width < height * mRatioWidth / mRatioHeight) { + setMeasuredDimension(width, width * mRatioHeight / mRatioWidth); + } else { + setMeasuredDimension(height * mRatioWidth / mRatioHeight, height); + } + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java new file mode 100644 index 0000000000000000000000000000000000000000..74737a8b883d23684220dd32bbd7a9e8ab4b2123 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -0,0 +1,708 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.app.AlertDialog; +import android.app.Dialog; +import android.app.DialogFragment; +import android.app.Fragment; +import android.content.Context; +import android.content.DialogInterface; +import android.content.pm.PackageInfo; +import android.content.pm.PackageManager; +import android.content.res.Configuration; +import android.graphics.Bitmap; +import android.graphics.ImageFormat; +import android.graphics.Matrix; +import android.graphics.Point; +import android.graphics.RectF; +import android.graphics.SurfaceTexture; +import android.hardware.camera2.CameraAccessException; +import android.hardware.camera2.CameraCaptureSession; +import android.hardware.camera2.CameraCharacteristics; +import android.hardware.camera2.CameraDevice; +import android.hardware.camera2.CameraManager; +import android.hardware.camera2.CaptureRequest; +import android.hardware.camera2.CaptureResult; +import android.hardware.camera2.TotalCaptureResult; +import android.hardware.camera2.params.StreamConfigurationMap; +import android.media.ImageReader; +import android.os.Bundle; +import android.os.Handler; +import android.os.HandlerThread; +import android.support.annotation.NonNull; +import android.support.v13.app.FragmentCompat; +import android.support.v4.content.ContextCompat; +import android.util.Log; +import android.util.Size; +import android.view.LayoutInflater; +import android.view.Surface; +import android.view.TextureView; +import android.view.View; +import android.view.ViewGroup; +import android.widget.TextView; +import android.widget.Toast; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +/** Basic fragments for the Camera. */ +public class Camera2BasicFragment extends Fragment + implements FragmentCompat.OnRequestPermissionsResultCallback { + + /** Tag for the {@link Log}. */ + private static final String TAG = "TfLiteCameraDemo"; + + private static final String FRAGMENT_DIALOG = "dialog"; + + private static final String HANDLE_THREAD_NAME = "CameraBackground"; + + private static final int PERMISSIONS_REQUEST_CODE = 1; + + private final Object lock = new Object(); + private boolean runClassifier = false; + private boolean checkedPermissions = false; + private TextView textView; + private ImageClassifier classifier; + + /** Max preview width that is guaranteed by Camera2 API */ + private static final int MAX_PREVIEW_WIDTH = 1920; + + /** Max preview height that is guaranteed by Camera2 API */ + private static final int MAX_PREVIEW_HEIGHT = 1080; + + /** + * {@link TextureView.SurfaceTextureListener} handles several lifecycle events on a {@link + * TextureView}. + */ + private final TextureView.SurfaceTextureListener surfaceTextureListener = + new TextureView.SurfaceTextureListener() { + + @Override + public void onSurfaceTextureAvailable(SurfaceTexture texture, int width, int height) { + openCamera(width, height); + } + + @Override + public void onSurfaceTextureSizeChanged(SurfaceTexture texture, int width, int height) { + configureTransform(width, height); + } + + @Override + public boolean onSurfaceTextureDestroyed(SurfaceTexture texture) { + return true; + } + + @Override + public void onSurfaceTextureUpdated(SurfaceTexture texture) {} + }; + + /** ID of the current {@link CameraDevice}. */ + private String cameraId; + + /** An {@link AutoFitTextureView} for camera preview. */ + private AutoFitTextureView textureView; + + /** A {@link CameraCaptureSession } for camera preview. */ + private CameraCaptureSession captureSession; + + /** A reference to the opened {@link CameraDevice}. */ + private CameraDevice cameraDevice; + + /** The {@link android.util.Size} of camera preview. */ + private Size previewSize; + + /** {@link CameraDevice.StateCallback} is called when {@link CameraDevice} changes its state. */ + private final CameraDevice.StateCallback stateCallback = + new CameraDevice.StateCallback() { + + @Override + public void onOpened(@NonNull CameraDevice currentCameraDevice) { + // This method is called when the camera is opened. We start camera preview here. + cameraOpenCloseLock.release(); + cameraDevice = currentCameraDevice; + createCameraPreviewSession(); + } + + @Override + public void onDisconnected(@NonNull CameraDevice currentCameraDevice) { + cameraOpenCloseLock.release(); + currentCameraDevice.close(); + cameraDevice = null; + } + + @Override + public void onError(@NonNull CameraDevice currentCameraDevice, int error) { + cameraOpenCloseLock.release(); + currentCameraDevice.close(); + cameraDevice = null; + Activity activity = getActivity(); + if (null != activity) { + activity.finish(); + } + } + }; + + /** An additional thread for running tasks that shouldn't block the UI. */ + private HandlerThread backgroundThread; + + /** A {@link Handler} for running tasks in the background. */ + private Handler backgroundHandler; + + /** An {@link ImageReader} that handles image capture. */ + private ImageReader imageReader; + + /** {@link CaptureRequest.Builder} for the camera preview */ + private CaptureRequest.Builder previewRequestBuilder; + + /** {@link CaptureRequest} generated by {@link #previewRequestBuilder} */ + private CaptureRequest previewRequest; + + /** A {@link Semaphore} to prevent the app from exiting before closing the camera. */ + private Semaphore cameraOpenCloseLock = new Semaphore(1); + + /** A {@link CameraCaptureSession.CaptureCallback} that handles events related to capture. */ + private CameraCaptureSession.CaptureCallback captureCallback = + new CameraCaptureSession.CaptureCallback() { + + @Override + public void onCaptureProgressed( + @NonNull CameraCaptureSession session, + @NonNull CaptureRequest request, + @NonNull CaptureResult partialResult) {} + + @Override + public void onCaptureCompleted( + @NonNull CameraCaptureSession session, + @NonNull CaptureRequest request, + @NonNull TotalCaptureResult result) {} + }; + + /** + * Shows a {@link Toast} on the UI thread for the classification results. + * + * @param text The message to show + */ + private void showToast(final String text) { + final Activity activity = getActivity(); + if (activity != null) { + activity.runOnUiThread( + new Runnable() { + @Override + public void run() { + textView.setText(text); + } + }); + } + } + + /** + * Resizes image. + * + * Attempting to use too large a preview size could exceed the camera bus' bandwidth limitation, + * resulting in gorgeous previews but the storage of garbage capture data. + * + * Given {@code choices} of {@code Size}s supported by a camera, choose the smallest one that is + * at least as large as the respective texture view size, and that is at most as large as the + * respective max size, and whose aspect ratio matches with the specified value. If such size + * doesn't exist, choose the largest one that is at most as large as the respective max size, and + * whose aspect ratio matches with the specified value. + * + * @param choices The list of sizes that the camera supports for the intended output class + * @param textureViewWidth The width of the texture view relative to sensor coordinate + * @param textureViewHeight The height of the texture view relative to sensor coordinate + * @param maxWidth The maximum width that can be chosen + * @param maxHeight The maximum height that can be chosen + * @param aspectRatio The aspect ratio + * @return The optimal {@code Size}, or an arbitrary one if none were big enough + */ + private static Size chooseOptimalSize( + Size[] choices, + int textureViewWidth, + int textureViewHeight, + int maxWidth, + int maxHeight, + Size aspectRatio) { + + // Collect the supported resolutions that are at least as big as the preview Surface + List bigEnough = new ArrayList<>(); + // Collect the supported resolutions that are smaller than the preview Surface + List notBigEnough = new ArrayList<>(); + int w = aspectRatio.getWidth(); + int h = aspectRatio.getHeight(); + for (Size option : choices) { + if (option.getWidth() <= maxWidth + && option.getHeight() <= maxHeight + && option.getHeight() == option.getWidth() * h / w) { + if (option.getWidth() >= textureViewWidth && option.getHeight() >= textureViewHeight) { + bigEnough.add(option); + } else { + notBigEnough.add(option); + } + } + } + + // Pick the smallest of those big enough. If there is no one big enough, pick the + // largest of those not big enough. + if (bigEnough.size() > 0) { + return Collections.min(bigEnough, new CompareSizesByArea()); + } else if (notBigEnough.size() > 0) { + return Collections.max(notBigEnough, new CompareSizesByArea()); + } else { + Log.e(TAG, "Couldn't find any suitable preview size"); + return choices[0]; + } + } + + public static Camera2BasicFragment newInstance() { + return new Camera2BasicFragment(); + } + + /** Layout the preview and buttons. */ + @Override + public View onCreateView( + LayoutInflater inflater, ViewGroup container, Bundle savedInstanceState) { + return inflater.inflate(R.layout.fragment_camera2_basic, container, false); + } + + /** Connect the buttons to their event handler. */ + @Override + public void onViewCreated(final View view, Bundle savedInstanceState) { + textureView = (AutoFitTextureView) view.findViewById(R.id.texture); + textView = (TextView) view.findViewById(R.id.text); + } + + /** Load the model and labels. */ + @Override + public void onActivityCreated(Bundle savedInstanceState) { + super.onActivityCreated(savedInstanceState); + try { + classifier = new ImageClassifier(getActivity()); + } catch (IOException e) { + Log.e(TAG, "Failed to initialize an image classifier."); + } + startBackgroundThread(); + } + + @Override + public void onResume() { + super.onResume(); + startBackgroundThread(); + + // When the screen is turned off and turned back on, the SurfaceTexture is already + // available, and "onSurfaceTextureAvailable" will not be called. In that case, we can open + // a camera and start preview from here (otherwise, we wait until the surface is ready in + // the SurfaceTextureListener). + if (textureView.isAvailable()) { + openCamera(textureView.getWidth(), textureView.getHeight()); + } else { + textureView.setSurfaceTextureListener(surfaceTextureListener); + } + } + + @Override + public void onPause() { + closeCamera(); + stopBackgroundThread(); + super.onPause(); + } + + @Override + public void onDestroy() { + classifier.close(); + super.onDestroy(); + } + + /** + * Sets up member variables related to camera. + * + * @param width The width of available size for camera preview + * @param height The height of available size for camera preview + */ + private void setUpCameraOutputs(int width, int height) { + Activity activity = getActivity(); + CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + for (String cameraId : manager.getCameraIdList()) { + CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId); + + // We don't use a front facing camera in this sample. + Integer facing = characteristics.get(CameraCharacteristics.LENS_FACING); + if (facing != null && facing == CameraCharacteristics.LENS_FACING_FRONT) { + continue; + } + + StreamConfigurationMap map = + characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP); + if (map == null) { + continue; + } + + // // For still image captures, we use the largest available size. + Size largest = + Collections.max( + Arrays.asList(map.getOutputSizes(ImageFormat.JPEG)), new CompareSizesByArea()); + imageReader = + ImageReader.newInstance( + largest.getWidth(), largest.getHeight(), ImageFormat.JPEG, /*maxImages*/ 2); + + // Find out if we need to swap dimension to get the preview size relative to sensor + // coordinate. + int displayRotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + // noinspection ConstantConditions + /* Orientation of the camera sensor */ + int sensorOrientation = characteristics.get(CameraCharacteristics.SENSOR_ORIENTATION); + boolean swappedDimensions = false; + switch (displayRotation) { + case Surface.ROTATION_0: + case Surface.ROTATION_180: + if (sensorOrientation == 90 || sensorOrientation == 270) { + swappedDimensions = true; + } + break; + case Surface.ROTATION_90: + case Surface.ROTATION_270: + if (sensorOrientation == 0 || sensorOrientation == 180) { + swappedDimensions = true; + } + break; + default: + Log.e(TAG, "Display rotation is invalid: " + displayRotation); + } + + Point displaySize = new Point(); + activity.getWindowManager().getDefaultDisplay().getSize(displaySize); + int rotatedPreviewWidth = width; + int rotatedPreviewHeight = height; + int maxPreviewWidth = displaySize.x; + int maxPreviewHeight = displaySize.y; + + if (swappedDimensions) { + rotatedPreviewWidth = height; + rotatedPreviewHeight = width; + maxPreviewWidth = displaySize.y; + maxPreviewHeight = displaySize.x; + } + + if (maxPreviewWidth > MAX_PREVIEW_WIDTH) { + maxPreviewWidth = MAX_PREVIEW_WIDTH; + } + + if (maxPreviewHeight > MAX_PREVIEW_HEIGHT) { + maxPreviewHeight = MAX_PREVIEW_HEIGHT; + } + + previewSize = + chooseOptimalSize( + map.getOutputSizes(SurfaceTexture.class), + rotatedPreviewWidth, + rotatedPreviewHeight, + maxPreviewWidth, + maxPreviewHeight, + largest); + + // We fit the aspect ratio of TextureView to the size of preview we picked. + int orientation = getResources().getConfiguration().orientation; + if (orientation == Configuration.ORIENTATION_LANDSCAPE) { + textureView.setAspectRatio(previewSize.getWidth(), previewSize.getHeight()); + } else { + textureView.setAspectRatio(previewSize.getHeight(), previewSize.getWidth()); + } + + this.cameraId = cameraId; + return; + } + } catch (CameraAccessException e) { + e.printStackTrace(); + } catch (NullPointerException e) { + // Currently an NPE is thrown when the Camera2API is used but not supported on the + // device this code runs. + ErrorDialog.newInstance(getString(R.string.camera_error)) + .show(getChildFragmentManager(), FRAGMENT_DIALOG); + } + } + + private String[] getRequiredPermissions() { + Activity activity = getActivity(); + try { + PackageInfo info = + activity + .getPackageManager() + .getPackageInfo(activity.getPackageName(), PackageManager.GET_PERMISSIONS); + String[] ps = info.requestedPermissions; + if (ps != null && ps.length > 0) { + return ps; + } else { + return new String[0]; + } + } catch (Exception e) { + return new String[0]; + } + } + + /** Opens the camera specified by {@link Camera2BasicFragment#cameraId}. */ + private void openCamera(int width, int height) { + if (!checkedPermissions && !allPermissionsGranted()) { + FragmentCompat.requestPermissions(this, getRequiredPermissions(), PERMISSIONS_REQUEST_CODE); + return; + } else { + checkedPermissions = true; + } + setUpCameraOutputs(width, height); + configureTransform(width, height); + Activity activity = getActivity(); + CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE); + try { + if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) { + throw new RuntimeException("Time out waiting to lock camera opening."); + } + manager.openCamera(cameraId, stateCallback, backgroundHandler); + } catch (CameraAccessException e) { + e.printStackTrace(); + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera opening.", e); + } + } + + private boolean allPermissionsGranted() { + for (String permission : getRequiredPermissions()) { + if (ContextCompat.checkSelfPermission(getActivity(), permission) + != PackageManager.PERMISSION_GRANTED) { + return false; + } + } + return true; + } + + @Override + public void onRequestPermissionsResult( + int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) { + super.onRequestPermissionsResult(requestCode, permissions, grantResults); + } + + /** Closes the current {@link CameraDevice}. */ + private void closeCamera() { + try { + cameraOpenCloseLock.acquire(); + if (null != captureSession) { + captureSession.close(); + captureSession = null; + } + if (null != cameraDevice) { + cameraDevice.close(); + cameraDevice = null; + } + if (null != imageReader) { + imageReader.close(); + imageReader = null; + } + } catch (InterruptedException e) { + throw new RuntimeException("Interrupted while trying to lock camera closing.", e); + } finally { + cameraOpenCloseLock.release(); + } + } + + /** Starts a background thread and its {@link Handler}. */ + private void startBackgroundThread() { + backgroundThread = new HandlerThread(HANDLE_THREAD_NAME); + backgroundThread.start(); + backgroundHandler = new Handler(backgroundThread.getLooper()); + synchronized (lock) { + runClassifier = true; + } + backgroundHandler.post(periodicClassify); + } + + /** Stops the background thread and its {@link Handler}. */ + private void stopBackgroundThread() { + backgroundThread.quitSafely(); + try { + backgroundThread.join(); + backgroundThread = null; + backgroundHandler = null; + synchronized (lock) { + runClassifier = false; + } + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + /** Takes photos and classify them periodically. */ + private Runnable periodicClassify = + new Runnable() { + @Override + public void run() { + synchronized (lock) { + if (runClassifier) { + classifyFrame(); + } + } + backgroundHandler.post(periodicClassify); + } + }; + + /** Creates a new {@link CameraCaptureSession} for camera preview. */ + private void createCameraPreviewSession() { + try { + SurfaceTexture texture = textureView.getSurfaceTexture(); + assert texture != null; + + // We configure the size of default buffer to be the size of camera preview we want. + texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight()); + + // This is the output Surface we need to start preview. + Surface surface = new Surface(texture); + + // We set up a CaptureRequest.Builder with the output Surface. + previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW); + previewRequestBuilder.addTarget(surface); + + // Here, we create a CameraCaptureSession for camera preview. + cameraDevice.createCaptureSession( + Arrays.asList(surface), + new CameraCaptureSession.StateCallback() { + + @Override + public void onConfigured(@NonNull CameraCaptureSession cameraCaptureSession) { + // The camera is already closed + if (null == cameraDevice) { + return; + } + + // When the session is ready, we start displaying the preview. + captureSession = cameraCaptureSession; + try { + // Auto focus should be continuous for camera preview. + previewRequestBuilder.set( + CaptureRequest.CONTROL_AF_MODE, + CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE); + + // Finally, we start displaying the camera preview. + previewRequest = previewRequestBuilder.build(); + captureSession.setRepeatingRequest( + previewRequest, captureCallback, backgroundHandler); + } catch (CameraAccessException e) { + e.printStackTrace(); + } + } + + @Override + public void onConfigureFailed(@NonNull CameraCaptureSession cameraCaptureSession) { + showToast("Failed"); + } + }, + null); + } catch (CameraAccessException e) { + e.printStackTrace(); + } + } + + /** + * Configures the necessary {@link android.graphics.Matrix} transformation to `textureView`. This + * method should be called after the camera preview size is determined in setUpCameraOutputs and + * also the size of `textureView` is fixed. + * + * @param viewWidth The width of `textureView` + * @param viewHeight The height of `textureView` + */ + private void configureTransform(int viewWidth, int viewHeight) { + Activity activity = getActivity(); + if (null == textureView || null == previewSize || null == activity) { + return; + } + int rotation = activity.getWindowManager().getDefaultDisplay().getRotation(); + Matrix matrix = new Matrix(); + RectF viewRect = new RectF(0, 0, viewWidth, viewHeight); + RectF bufferRect = new RectF(0, 0, previewSize.getHeight(), previewSize.getWidth()); + float centerX = viewRect.centerX(); + float centerY = viewRect.centerY(); + if (Surface.ROTATION_90 == rotation || Surface.ROTATION_270 == rotation) { + bufferRect.offset(centerX - bufferRect.centerX(), centerY - bufferRect.centerY()); + matrix.setRectToRect(viewRect, bufferRect, Matrix.ScaleToFit.FILL); + float scale = + Math.max( + (float) viewHeight / previewSize.getHeight(), + (float) viewWidth / previewSize.getWidth()); + matrix.postScale(scale, scale, centerX, centerY); + matrix.postRotate(90 * (rotation - 2), centerX, centerY); + } else if (Surface.ROTATION_180 == rotation) { + matrix.postRotate(180, centerX, centerY); + } + textureView.setTransform(matrix); + } + + /** Classifies a frame from the preview stream. */ + private void classifyFrame() { + if (classifier == null || getActivity() == null || cameraDevice == null) { + showToast("Uninitialized Classifier or invalid context."); + return; + } + Bitmap bitmap = + textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y); + String textToShow = classifier.classifyFrame(bitmap); + bitmap.recycle(); + showToast(textToShow); + } + + /** Compares two {@code Size}s based on their areas. */ + private static class CompareSizesByArea implements Comparator { + + @Override + public int compare(Size lhs, Size rhs) { + // We cast here to ensure the multiplications won't overflow + return Long.signum( + (long) lhs.getWidth() * lhs.getHeight() - (long) rhs.getWidth() * rhs.getHeight()); + } + } + + /** Shows an error message dialog. */ + public static class ErrorDialog extends DialogFragment { + + private static final String ARG_MESSAGE = "message"; + + public static ErrorDialog newInstance(String message) { + ErrorDialog dialog = new ErrorDialog(); + Bundle args = new Bundle(); + args.putString(ARG_MESSAGE, message); + dialog.setArguments(args); + return dialog; + } + + @Override + public Dialog onCreateDialog(Bundle savedInstanceState) { + final Activity activity = getActivity(); + return new AlertDialog.Builder(activity) + .setMessage(getArguments().getString(ARG_MESSAGE)) + .setPositiveButton( + android.R.string.ok, + new DialogInterface.OnClickListener() { + @Override + public void onClick(DialogInterface dialogInterface, int i) { + activity.finish(); + } + }) + .create(); + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..e7161ddb26b379f9dcf6addefa585ccf6431c055 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/CameraActivity.java @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.os.Bundle; + +/** Main {@code Activity} class for the Camera app. */ +public class CameraActivity extends Activity { + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + setContentView(R.layout.activity_camera); + if (null == savedInstanceState) { + getFragmentManager() + .beginTransaction() + .replace(R.id.container, Camera2BasicFragment.newInstance()) + .commit(); + } + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java new file mode 100644 index 0000000000000000000000000000000000000000..e7bad4637041d003c1e507d81c0c30404c587653 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; +import android.content.res.AssetFileDescriptor; +import android.graphics.Bitmap; +import android.os.SystemClock; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import org.tensorflow.lite.Interpreter; + +/** Classifies images with Tensorflow Lite. */ +public class ImageClassifier { + + /** Tag for the {@link Log}. */ + private static final String TAG = "TfLiteCameraDemo"; + + /** Name of the model file stored in Assets. */ + private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite"; + + /** Name of the label file stored in Assets. */ + private static final String LABEL_PATH = "labels.txt"; + + /** Number of results to show in the UI. */ + private static final int RESULTS_TO_SHOW = 3; + + /** Dimensions of inputs. */ + private static final int DIM_BATCH_SIZE = 1; + + private static final int DIM_PIXEL_SIZE = 3; + + static final int DIM_IMG_SIZE_X = 224; + static final int DIM_IMG_SIZE_Y = 224; + + /* Preallocated buffers for storing image data in. */ + private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y]; + + /** An instance of the driver class to run model inference with Tensorflow Lite. */ + private Interpreter tflite; + + /** Labels corresponding to the output of the vision model. */ + private List labelList; + + /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */ + private ByteBuffer imgData = null; + + /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ + private byte[][] labelProbArray = null; + + private PriorityQueue> sortedLabels = + new PriorityQueue<>( + RESULTS_TO_SHOW, + new Comparator>() { + @Override + public int compare(Map.Entry o1, Map.Entry o2) { + return (o1.getValue()).compareTo(o2.getValue()); + } + }); + + /** Initializes an {@code ImageClassifier}. */ + ImageClassifier(Activity activity) throws IOException { + tflite = new Interpreter(loadModelFile(activity)); + labelList = loadLabelList(activity); + imgData = + ByteBuffer.allocateDirect( + DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); + imgData.order(ByteOrder.nativeOrder()); + labelProbArray = new byte[1][labelList.size()]; + Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); + } + + /** Classifies a frame from the preview stream. */ + String classifyFrame(Bitmap bitmap) { + if (tflite == null) { + Log.e(TAG, "Image classifier has not been initialized; Skipped."); + return "Uninitialized Classifier."; + } + convertBitmapToByteBuffer(bitmap); + // Here's where the magic happens!!! + long startTime = SystemClock.uptimeMillis(); + tflite.run(imgData, labelProbArray); + long endTime = SystemClock.uptimeMillis(); + Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime)); + String textToShow = printTopKLabels(); + textToShow = Long.toString(endTime - startTime) + "ms" + textToShow; + return textToShow; + } + + /** Closes tflite to release resources. */ + public void close() { + tflite.close(); + tflite = null; + } + + /** Reads label list from Assets. */ + private List loadLabelList(Activity activity) throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH))); + String line; + while ((line = reader.readLine()) != null) { + labelList.add(line); + } + reader.close(); + return labelList; + } + + /** Memory-map the model file in Assets. */ + private MappedByteBuffer loadModelFile(Activity activity) throws IOException { + AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } + + /** Writes Image data into a {@code ByteBuffer}. */ + private void convertBitmapToByteBuffer(Bitmap bitmap) { + if (imgData == null) { + return; + } + imgData.rewind(); + bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); + // Convert the image to floating point. + int pixel = 0; + long startTime = SystemClock.uptimeMillis(); + for (int i = 0; i < DIM_IMG_SIZE_X; ++i) { + for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) { + final int val = intValues[pixel++]; + imgData.put((byte) ((val >> 16) & 0xFF)); + imgData.put((byte) ((val >> 8) & 0xFF)); + imgData.put((byte) (val & 0xFF)); + } + } + long endTime = SystemClock.uptimeMillis(); + Log.d(TAG, "Timecost to put values into ByteBuffer: " + Long.toString(endTime - startTime)); + } + + /** Prints top-K labels, to be shown in UI as the results. */ + private String printTopKLabels() { + for (int i = 0; i < labelList.size(); ++i) { + sortedLabels.add( + new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f)); + if (sortedLabels.size() > RESULTS_TO_SHOW) { + sortedLabels.poll(); + } + } + String textToShow = ""; + final int size = sortedLabels.size(); + for (int i = 0; i < size; ++i) { + Map.Entry label = sortedLabels.poll(); + textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow; + } + return textToShow; + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..e0a70008b10b98162b4710385e21ac65333f1231 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..c22509d8dfccae14d9470e3042a9ed5b469ca2c9 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png new file mode 100644 index 0000000000000000000000000000000000000000..a84e3ef52c6dce90ccfa98f64db25fad7a8f0289 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/tile.9.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..520c2dd100b092fad5987dc1b41575e1681b459c Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..d68af39186ca9cd2bc755cad8397467a11844a1d Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..1347b091983ebd9d3d58e29194b9335b6c138a2b Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..15e419b7ccd88651bd21dac36853a827fc4075b8 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png new file mode 100644 index 0000000000000000000000000000000000000000..fd933333b71590608d91201aad29553f9b365b6a Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_action_info.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png new file mode 100644 index 0000000000000000000000000000000000000000..342ce34e1663960d8d7050a9be57face3571d336 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png differ diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml new file mode 100644 index 0000000000000000000000000000000000000000..a84f1bbfa0cb48a3fc335c9bc4aa7d8e93d20e75 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml @@ -0,0 +1,50 @@ + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml new file mode 100644 index 0000000000000000000000000000000000000000..286e549c6569cef4b7a9e46f9c73e6f43b6d3045 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/activity_camera.xml @@ -0,0 +1,22 @@ + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml new file mode 100644 index 0000000000000000000000000000000000000000..15305c436e0d997af15a326ab4027ea713ed8098 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml @@ -0,0 +1,45 @@ + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml new file mode 100644 index 0000000000000000000000000000000000000000..22074a2bdbaf60efff64d98a0788ef797a966f80 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-dimens.xml @@ -0,0 +1,24 @@ + + + + + + + @dimen/margin_huge + @dimen/margin_medium + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..03d1974183dd645178c07d247d61b83d067806be --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-sw600dp/template-styles.xml @@ -0,0 +1,25 @@ + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..8c1ea66f28907ac211f355f4220ff4582cfb31eb --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values-v11/template-styles.xml @@ -0,0 +1,22 @@ + + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..0a71dbd0e8010f5e3a176de1f7e8321331289f7c --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/base-strings.xml @@ -0,0 +1,30 @@ + + + + + TfLiteCameraDemo + + + + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml new file mode 100644 index 0000000000000000000000000000000000000000..4b75d2b2bda0f95166d0442ebae19cedcad162d8 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/colors.xml @@ -0,0 +1,19 @@ + + + + #cc4285f4 + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml new file mode 100644 index 0000000000000000000000000000000000000000..a08ec3eb629250a727cec49a822375fe5569f455 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/strings.xml @@ -0,0 +1,24 @@ + + + Picture + Info + This sample needs camera permission. + This device doesn\'t support Camera2 API. + NN:On + NN:Off + Use NNAPI + diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml new file mode 100644 index 0000000000000000000000000000000000000000..3f3bdfb49480e779c108cd15da854ae82a118d52 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/values/styles.xml @@ -0,0 +1,18 @@ + + + + + + + diff --git a/tensorflow/contrib/lite/java/demo/build.gradle b/tensorflow/contrib/lite/java/demo/build.gradle new file mode 100644 index 0000000000000000000000000000000000000000..b78a0b86c939620b6f05483ce45c4d3ef0ef595e --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/build.gradle @@ -0,0 +1,23 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath 'com.android.tools.build:gradle:2.3.1' + + // NOTE: Do not place your application dependencies here; they belong + // in the individual module build.gradle files + } +} + +allprojects { + repositories { + jcenter() + } +} + +task clean(type: Delete) { + delete rootProject.buildDir +} diff --git a/tensorflow/contrib/lite/java/demo/gradle.properties b/tensorflow/contrib/lite/java/demo/gradle.properties new file mode 100644 index 0000000000000000000000000000000000000000..aac7c9b4614ccfde6c721f24994cf30885a791d0 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradle.properties @@ -0,0 +1,17 @@ +# Project-wide Gradle settings. + +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. + +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html + +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx1536m + +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..13372aef5e24af05341d49695ee84e5f9b594659 Binary files /dev/null and b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.jar differ diff --git a/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000000000000000000000000000000000..fa7a38a0e43eecd1e7292dd49efa79a5d0742e2a --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Thu Sep 28 09:01:41 PDT 2017 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-3.3-all.zip diff --git a/tensorflow/contrib/lite/java/demo/gradlew b/tensorflow/contrib/lite/java/demo/gradlew new file mode 100755 index 0000000000000000000000000000000000000000..9d82f78915133e1c35a6ea51252590fb38efac2f --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradlew @@ -0,0 +1,160 @@ +#!/usr/bin/env bash + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn ( ) { + echo "$*" +} + +die ( ) { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; +esac + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules +function splitJvmOpts() { + JVM_OPTS=("$@") +} +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" + +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" diff --git a/tensorflow/contrib/lite/java/demo/gradlew.bat b/tensorflow/contrib/lite/java/demo/gradlew.bat new file mode 100644 index 0000000000000000000000000000000000000000..8a0b282aa6885fb573c106b3551f7275c5f17e8e --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/gradlew.bat @@ -0,0 +1,90 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windowz variants + +if not "%OS%" == "Windows_NT" goto win9xME_args +if "%@eval[2+2]" == "4" goto 4NT_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* +goto execute + +:4NT_args +@rem Get arguments from the 4NT Shell from JP Software +set CMD_LINE_ARGS=%$ + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/tensorflow/contrib/lite/java/demo/settings.gradle b/tensorflow/contrib/lite/java/demo/settings.gradle new file mode 100644 index 0000000000000000000000000000000000000000..e7b4def49cb53d9aa04228dd3edb14c9e635e003 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/settings.gradle @@ -0,0 +1 @@ +include ':app' diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java new file mode 100644 index 0000000000000000000000000000000000000000..d63c299589d2e8ce1051a52d29b533ed126bbcf7 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/DataType.java @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** Type of elements in a {@link TfLiteTensor}. */ +enum DataType { + /** 32-bit single precision floating point. */ + FLOAT32(1), + + /** 32-bit signed integer. */ + INT32(2), + + /** 8-bit unsigned integer. */ + UINT8(3), + + /** 64-bit signed integer. */ + INT64(4), + + /** A {@link ByteBuffer}. */ + BYTEBUFFER(999); + + private final int value; + + DataType(int value) { + this.value = value; + } + + /** Corresponding value of the kTfLite* enum in the TensorFlow Lite CC API. */ + int getNumber() { + return value; + } + + /** Converts an integer to the corresponding type. */ + static DataType fromNumber(int c) { + for (DataType t : values) { + if (t.value == c) { + return t; + } + } + throw new IllegalArgumentException( + "DataType " + c + " is not recognized in Java (version " + TensorFlowLite.version() + ")"); + } + + /** Returns byte size of the type. */ + int elemByteSize() { + switch (this) { + case FLOAT32: + return 4; + case INT32: + return 4; + case UINT8: + return 1; + case INT64: + return 8; + case BYTEBUFFER: + return 1; + } + throw new IllegalArgumentException("DataType " + this + " is not supported yet"); + } + + // Cached to avoid copying it + private static final DataType[] values = values(); +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java new file mode 100644 index 0000000000000000000000000000000000000000..dd883d69d2065236ee29012b9bde99972aefbcf7 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java @@ -0,0 +1,172 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; +import javax.validation.constraints.NotNull; + +/** + * Driver class to drive model inference with TensorFlow Lite. + * + *

A {@code Interpreter} encapsulates a pre-trained TensorFlow Lite model, in which operations + * are executed for model inference. + * + *

For example, if a model takes only one input and returns only one output: + * + *

{@code
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ *   interpreter.run(input, output);
+ * }
+ * }
+ * + *

If a model takes multiple inputs or outputs: + * + *

{@code
+ * Object[] inputs = {input0, input1, ...};
+ * Map map_of_indices_to_outputs = new HashMap<>();
+ * float[][][] ith_output = new float[3][2][4];
+ * map_of_indices_to_outputs.put(i, ith_output);
+ * try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)) {
+ *   interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
+ * }
+ * }
+ * + *

Orders of inputs and outputs are determined when converting TensorFlow model to TensorFlowLite + * model with Toco. + * + *

WARNING:Instances of a {@code Interpreter} is not thread-safe. A {@code + * Interpreter} owns resources that must be explicitly freed by invoking {@link #close()} + */ +public final class Interpreter implements AutoCloseable { + + /** + * Initializes a {@code Interpreter} + * + * @param modelFile: a File of a pre-trained TF Lite model. + */ + public Interpreter(@NotNull File modelFile) { + if (modelFile == null) { + return; + } + wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath()); + } + + /** + * Initializes a {@code Interpreter} with a {@code MappedByteBuffer} to the model file. + * + *

The {@code MappedByteBuffer} should remain unchanged after the construction of a {@code + * Interpreter}. + */ + public Interpreter(@NotNull MappedByteBuffer mappedByteBuffer) { + wrapper = new NativeInterpreterWrapper(mappedByteBuffer); + } + + /** + * Runs model inference if the model takes only one input, and provides only one output. + * + * @param input an array or multidimensional array, or a {@link ByteBuffer} of primitive types + * including int, float, long, and byte. {@link ByteBuffer} is the preferred way to pass large + * input data. When {@link ByteBuffer} is used, its content should remain unchanged until + * model inference is done. + * @param output a multidimensional array of output data. + */ + public void run(@NotNull Object input, @NotNull Object output) { + Object[] inputs = {input}; + Map outputs = new HashMap<>(); + outputs.put(0, output); + runForMultipleInputsOutputs(inputs, outputs); + } + + /** + * Runs model inference if the model takes multiple inputs, or returns multiple outputs. + * + * @param inputs an array of input data. The inputs should be in the same order as inputs of the + * model. Each input can be an array or multidimensional array, or a {@link ByteBuffer} of + * primitive types including int, float, long, and byte. {@link ByteBuffer} is the preferred + * way to pass large input data. When {@link ByteBuffer} is used, its content should remain + * unchanged until model inference is done. + * @param outputs a map mapping output indices to multidimensional arrays of output data. It only + * needs to keep entries for the outputs to be used. + */ + public void runForMultipleInputsOutputs( + @NotNull Object[] inputs, @NotNull Map outputs) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + Tensor[] tensors = wrapper.run(inputs); + if (outputs == null || tensors == null || outputs.size() > tensors.length) { + throw new IllegalArgumentException("Outputs do not match with model outputs."); + } + final int size = tensors.length; + for (Integer idx : outputs.keySet()) { + if (idx == null || idx < 0 || idx >= size) { + throw new IllegalArgumentException( + String.format("Invalid index of output %d (should be in range [0, %d))", idx, size)); + } + tensors[idx].copyTo(outputs.get(idx)); + } + } + + /** + * Resizes idx-th input of the native model to the given dims. + * + *

IllegalArgumentException will be thrown if it fails to resize. + */ + public void resizeInput(int idx, @NotNull int[] dims) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + wrapper.resizeInput(idx, dims); + } + + /** + * Gets index of an input given the op name of the input. + * + *

IllegalArgumentException will be thrown if the op name does not exist in the model file used + * to initialize the {@link Interpreter}. + */ + public int getInputIndex(String opName) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + return wrapper.getInputIndex(opName); + } + + /** + * Gets index of an output given the op name of the output. + * + *

IllegalArgumentException will be thrown if the op name does not exist in the model file used + * to initialize the {@link Interpreter}. + */ + public int getOutputIndex(String opName) { + if (wrapper == null) { + throw new IllegalStateException("The Interpreter has already been closed."); + } + return wrapper.getOutputIndex(opName); + } + + /** Release resources associated with the {@code Interpreter}. */ + @Override + public void close() { + wrapper.close(); + wrapper = null; + } + + NativeInterpreterWrapper wrapper; +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java new file mode 100644 index 0000000000000000000000000000000000000000..1939a078ad8031b99620773c9b91335c4e8f7b22 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -0,0 +1,276 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.lang.reflect.Array; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.MappedByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * A wrapper wraps native interpreter and controls model execution. + * + *

WARNING: Resources consumed by the {@code NativeInterpreterWrapper} object must be + * explicitly freed by invoking the {@link #close()} method when the {@code + * NativeInterpreterWrapper} object is no longer needed. + */ +final class NativeInterpreterWrapper implements AutoCloseable { + + NativeInterpreterWrapper(String modelPath) { + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModel(modelPath, errorHandle); + interpreterHandle = createInterpreter(modelHandle); + } + + /** + * Initializes a {@code NativeInterpreterWrapper} with a {@code MappedByteBuffer}. The + * MappedByteBuffer should not be modified after the construction of a {@code + * NativeInterpreterWrapper}. + */ + NativeInterpreterWrapper(MappedByteBuffer mappedByteBuffer) { + modelByteBuffer = mappedByteBuffer; + errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); + modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); + interpreterHandle = createInterpreter(modelHandle); + } + + /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ + @Override + public void close() { + delete(errorHandle, modelHandle, interpreterHandle); + errorHandle = 0; + modelHandle = 0; + interpreterHandle = 0; + modelByteBuffer = null; + inputsIndexes = null; + outputsIndexes = null; + } + + /** Sets inputs, runs model inference and returns outputs. */ + Tensor[] run(Object[] inputs) { + if (inputs == null || inputs.length == 0) { + throw new IllegalArgumentException("Invalid inputs. Inputs should not be null or empty."); + } + int[] dataTypes = new int[inputs.length]; + Object[] sizes = new Object[inputs.length]; + int[] numsOfBytes = new int[inputs.length]; + for (int i = 0; i < inputs.length; ++i) { + DataType dataType = dataTypeOf(inputs[i]); + dataTypes[i] = dataType.getNumber(); + if (dataType == DataType.BYTEBUFFER) { + ByteBuffer buffer = (ByteBuffer) inputs[i]; + if (buffer.order() != ByteOrder.nativeOrder()) { + throw new IllegalArgumentException( + "Invalid ByteBuffer. It shoud use ByteOrder.nativeOrder()."); + } + numsOfBytes[i] = buffer.limit(); + sizes[i] = getInputDims(interpreterHandle, i, numsOfBytes[i]); + } else if (isNonEmptyArray(inputs[i])) { + int[] dims = shapeOf(inputs[i]); + sizes[i] = dims; + numsOfBytes[i] = dataType.elemByteSize() * numElements(dims); + } else { + throw new IllegalArgumentException( + String.format( + "%d-th element of the %d inputs is not an array or a ByteBuffer.", + i, inputs.length)); + } + } + long[] outputsHandles = + run(interpreterHandle, errorHandle, sizes, dataTypes, numsOfBytes, inputs); + if (outputsHandles == null || outputsHandles.length == 0) { + throw new IllegalStateException("Interpreter has no outputs."); + } + Tensor[] outputs = new Tensor[outputsHandles.length]; + for (int i = 0; i < outputsHandles.length; ++i) { + outputs[i] = Tensor.fromHandle(outputsHandles[i]); + } + return outputs; + } + + /** Resizes dimensions of a specific input. */ + void resizeInput(int idx, int[] dims) { + resizeInput(interpreterHandle, errorHandle, idx, dims); + } + + void setUseNNAPI(boolean useNNAPI) { + useNNAPI(interpreterHandle, useNNAPI); + } + + /** Gets index of an input given its name. */ + int getInputIndex(String name) { + if (inputsIndexes == null) { + String[] names = getInputNames(interpreterHandle); + inputsIndexes = new HashMap<>(); + if (names != null) { + for (int i = 0; i < names.length; ++i) { + inputsIndexes.put(names[i], i); + } + } + } + if (inputsIndexes.containsKey(name)) { + return inputsIndexes.get(name); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a valid name for any input. The indexes of the inputs are %s", + name, inputsIndexes.toString())); + } + } + + /** Gets index of an output given its name. */ + int getOutputIndex(String name) { + if (outputsIndexes == null) { + String[] names = getOutputNames(interpreterHandle); + outputsIndexes = new HashMap<>(); + if (names != null) { + for (int i = 0; i < names.length; ++i) { + outputsIndexes.put(names[i], i); + } + } + } + if (outputsIndexes.containsKey(name)) { + return outputsIndexes.get(name); + } else { + throw new IllegalArgumentException( + String.format( + "%s is not a valid name for any output. The indexes of the outputs are %s", + name, outputsIndexes.toString())); + } + } + + static int numElements(int[] shape) { + if (shape == null) { + return 0; + } + int n = 1; + for (int i = 0; i < shape.length; i++) { + n *= shape[i]; + } + return n; + } + + static boolean isNonEmptyArray(Object o) { + return (o != null && o.getClass().isArray() && Array.getLength(o) != 0); + } + + /** Returns the type of the data. */ + static DataType dataTypeOf(Object o) { + if (o != null) { + Class c = o.getClass(); + while (c.isArray()) { + c = c.getComponentType(); + } + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } else if (ByteBuffer.class.isInstance(o)) { + return DataType.BYTEBUFFER; + } + } + throw new IllegalArgumentException("cannot resolve DataType of " + o.getClass().getName()); + } + + /** Returns the shape of an object as an int array. */ + static int[] shapeOf(Object o) { + int size = numDimensions(o); + int[] dimensions = new int[size]; + fillShape(o, 0, dimensions); + return dimensions; + } + + static int numDimensions(Object o) { + if (o == null || !o.getClass().isArray()) { + return 0; + } + if (Array.getLength(o) == 0) { + throw new IllegalArgumentException("array lengths cannot be 0."); + } + return 1 + numDimensions(Array.get(o, 0)); + } + + static void fillShape(Object o, int dim, int[] shape) { + if (shape == null || dim == shape.length) { + return; + } + final int len = Array.getLength(o); + if (shape[dim] == 0) { + shape[dim] = len; + } else if (shape[dim] != len) { + throw new IllegalArgumentException( + String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); + } + for (int i = 0; i < len; ++i) { + fillShape(Array.get(o, i), dim + 1, shape); + } + } + + private static final int ERROR_BUFFER_SIZE = 512; + + private long errorHandle; + + private long interpreterHandle; + + private long modelHandle; + + private int inputSize; + + private MappedByteBuffer modelByteBuffer; + + private Map inputsIndexes; + + private Map outputsIndexes; + + private static native String[] getInputNames(long interpreterHandle); + + private static native String[] getOutputNames(long interpreterHandle); + + private static native void resizeInput( + long interpreterHandle, long errorHandle, int inputIdx, int[] dims); + + private static native void useNNAPI(long interpreterHandle, boolean state); + + private static native long createErrorReporter(int size); + + private static native long createModel(String modelPathOrBuffer, long errorHandle); + + private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); + + private static native long createInterpreter(long modelHandle); + + private static native long[] run( + long interpreterHandle, + long errorHandle, + Object[] sizes, + int[] dtypes, + int[] numsOfBytes, + Object[] values); + + private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); + + private static native int[] getInputDims(long interpreterHandle, int inputIdx, int numBytes); + + static { + TensorFlowLite.init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java new file mode 100644 index 0000000000000000000000000000000000000000..54ace6c63ce5bd1b38be744176d0378e3cc8a1d3 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import java.util.Arrays; + +/** + * A typed multi-dimensional array used in Tensorflow Lite. + * + *

The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not + * needed to be closed here. + */ +final class Tensor { + + static Tensor fromHandle(long nativeHandle) { + return new Tensor(nativeHandle); + } + + /** Reads Tensor content into an array. */ + T copyTo(T dst) { + if (NativeInterpreterWrapper.dataTypeOf(dst) != dtype) { + throw new IllegalArgumentException( + String.format( + "Cannot convert an TensorFlowLite tensor with type %s to a Java object of " + + "type %s (which is compatible with the TensorFlowLite type %s)", + dtype, dst.getClass().getName(), NativeInterpreterWrapper.dataTypeOf(dst))); + } + int[] dstShape = NativeInterpreterWrapper.shapeOf(dst); + if (!Arrays.equals(dstShape, shapeCopy)) { + throw new IllegalArgumentException( + String.format( + "Shape of output target %s does not match with the shape of the Tensor %s.", + Arrays.toString(dstShape), Arrays.toString(shapeCopy))); + } + readMultiDimensionalArray(nativeHandle, dst); + return dst; + } + + final long nativeHandle; + final DataType dtype; + final int[] shapeCopy; + + private Tensor(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.dtype = DataType.fromNumber(dtype(nativeHandle)); + this.shapeCopy = shape(nativeHandle); + } + + private static native int dtype(long handle); + + private static native int[] shape(long handle); + + private static native void readMultiDimensionalArray(long handle, Object value); + + static { + TensorFlowLite.init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java new file mode 100644 index 0000000000000000000000000000000000000000..711638a9f995ce270cd362b93a7bcfca990430dc --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** Static utility methods loading the TensorFlowLite runtime. */ +public final class TensorFlowLite { + + private static final String LIBNAME = "tensorflowlite_jni"; + + private TensorFlowLite() {} + + /** Returns the version of the underlying TensorFlowLite runtime. */ + public static native String version(); + + /** + * Load the TensorFlowLite runtime C library. + */ + static boolean init() { + try { + System.loadLibrary(LIBNAME); + return true; + } catch (UnsatisfiedLinkError e) { + System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage()); + return false; + } + } + + static { + init(); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java new file mode 100644 index 0000000000000000000000000000000000000000..68e6a0f57810f6d9675a5d1193601e43e172ab74 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/package-info.java @@ -0,0 +1,17 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/** Defines classes to load and execute TensorFlowLite models. */ +package org.tensorflow.lite; diff --git a/tensorflow/contrib/lite/java/src/main/native/BUILD b/tensorflow/contrib/lite/java/src/main/native/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..15806d57c8ed7a45d2db9b80e2aab8e22349ee3e --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/BUILD @@ -0,0 +1,108 @@ +# Description: +# Java Native Interface (JNI) library intended for implementing the +# TensorFlow Lite Java API using the TensorFlow Lite CC library. + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "native_framework_only", + srcs = [ + "exception_jni.cc", + "nativeinterpreterwrapper_jni.cc", + "tensor_jni.cc", + "tensorflow_lite_jni.cc", + ] + select({ + # The Android toolchain makes "jni.h" available in the include path. + # For non-Android toolchains, generate jni.h and jni_md.h. + "//tensorflow:android": [], + "//conditions:default": [ + ":jni.h", + ":jni_md.h", + ], + }), + hdrs = [ + "exception_jni.h", + "nativeinterpreterwrapper_jni.h", + "tensor_jni.h", + "tensorflow_lite_jni.h", + ], + copts = tflite_copts(), + includes = select({ + "//tensorflow:android": [], + "//conditions:default": ["."], + }), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + ], + alwayslink = 1, +) + +# Silly rules to make +# #include +# in the source headers work +# (in combination with the "includes" attribute of the tf_cuda_library rule +# above. Not needed when using the Android toolchain). +# +# Inspired from: +# https://github.com/bazelbuild/bazel/blob/f99a0543f8d97339d32075c7176b79f35be84606/src/main/native/BUILD +# but hopefully there is a simpler alternative to this. +genrule( + name = "copy_jni_h", + srcs = ["@bazel_tools//tools/jdk:jni_header"], + outs = ["jni.h"], + cmd = "cp -f $< $@", +) + +genrule( + name = "copy_jni_md_h", + srcs = select({ + "//tensorflow:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], + }), + outs = ["jni_md.h"], + cmd = "cp -f $< $@", +) + +# This includes all ops. If you want a smaller binary, you should copy and +# modify builtin_ops_jni.cc. You should then link your binary against both +# ":native_framework_only" and your own version of ":native_builtin_ops". +cc_library( + name = "native", + srcs = [ + "builtin_ops_jni.cc", + ], + copts = tflite_copts(), + deps = [ + ":native_framework_only", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], + alwayslink = 1, +) + +exports_files( + [ + "version_script.lds", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..cce356370fa770de3e44438f08470077fb07c04c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/builtin_ops_jni.cc @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/register.h" + +namespace tflite { + +// The JNI code in interpreter_jni.cc expects a CreateOpResolver() function in +// the tflite namespace. This one instantiates a BuiltinOpResolver, with all the +// builtin ops. For smaller binary sizes users should avoid linking this in, and +// should provide a custom make CreateOpResolver() instead. +std::unique_ptr CreateOpResolver() { // NOLINT + return std::unique_ptr( + new tflite::ops::builtin::BuiltinOpResolver()); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..1578c9e3ddd034ad9ce17c8c3ae6c942258e2a55 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" + +const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException"; +const char kIllegalStateException[] = "java/lang/IllegalStateException"; +const char kNullPointerException[] = "java/lang/NullPointerException"; +const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException"; +const char kUnsupportedOperationException[] = + "java/lang/UnsupportedOperationException"; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + const size_t max_msg_len = 512; + auto* message = static_cast(malloc(max_msg_len)); + if (vsnprintf(message, max_msg_len, fmt, args) >= 0) { + env->ThrowNew(env->FindClass(clazz), message); + } else { + env->ThrowNew(env->FindClass(clazz), ""); + } + free(message); + va_end(args); +} + +BufferErrorReporter::BufferErrorReporter(JNIEnv* env, int limit) { + buffer_ = new char[limit]; + if (!buffer_) { + throwException(env, kNullPointerException, + "Malloc of BufferErrorReporter to hold %d char failed.", + limit); + return; + } + start_idx_ = 0; + end_idx_ = limit - 1; +} + +BufferErrorReporter::~BufferErrorReporter() { delete[] buffer_; } + +int BufferErrorReporter::Report(const char* format, va_list args) { + int size = 0; + if (start_idx_ < end_idx_) { + size = vsnprintf(buffer_ + start_idx_, end_idx_ - start_idx_, format, args); + } + start_idx_ += size; + return size; +} + +const char* BufferErrorReporter::CachedErrorMessage() { return buffer_; } diff --git a/tensorflow/contrib/lite/java/src/main/native/exception_jni.h b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..3ffff052df73c5cb21bb6522d31dc615c38f7d1f --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/exception_jni.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ + +#include +#include "tensorflow/contrib/lite/error_reporter.h" + +#ifdef __cplusplus +extern "C" { +#endif + +extern const char kIllegalArgumentException[]; +extern const char kIllegalStateException[]; +extern const char kNullPointerException[]; +extern const char kIndexOutOfBoundsException[]; +extern const char kUnsupportedOperationException[]; + +void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...); + +class BufferErrorReporter : public tflite::ErrorReporter { + public: + BufferErrorReporter(JNIEnv* env, int limit); + virtual ~BufferErrorReporter(); + int Report(const char* format, va_list args) override; + const char* CachedErrorMessage(); + + private: + char* buffer_; + int start_idx_ = 0; + int end_idx_ = 0; +}; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_EXCEPTION_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc6462eb5466e14769f94c5103984f5201b4b8dc --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -0,0 +1,446 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h" + +namespace { + +const int kByteBufferValue = 999; +const int kBufferSize = 256; + +tflite::Interpreter* convertLongToInterpreter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to Interpreter."); + return nullptr; + } + return reinterpret_cast(handle); +} + +tflite::FlatBufferModel* convertLongToModel(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, "Invalid handle to model."); + return nullptr; + } + return reinterpret_cast(handle); +} + +BufferErrorReporter* convertLongToErrorReporter(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to ErrorReporter."); + return nullptr; + } + return reinterpret_cast(handle); +} + +std::vector convertJIntArrayToVector(JNIEnv* env, jintArray inputs) { + int size = static_cast(env->GetArrayLength(inputs)); + std::vector outputs(size, 0); + jint* ptr = env->GetIntArrayElements(inputs, nullptr); + if (ptr == nullptr) { + throwException(env, kIllegalArgumentException, + "Empty dimensions of input array."); + return {}; + } + for (int i = 0; i < size; ++i) { + outputs[i] = ptr[i]; + } + env->ReleaseIntArrayElements(inputs, ptr, JNI_ABORT); + return outputs; +} + +bool isByteBuffer(jint data_type) { return data_type == kByteBufferValue; } + +TfLiteType resolveDataType(jint data_type) { + switch (data_type) { + case 1: + return kTfLiteFloat32; + case 2: + return kTfLiteInt32; + case 3: + return kTfLiteUInt8; + case 4: + return kTfLiteInt64; + default: + return kTfLiteNoType; + } +} + +void printDims(char* buffer, int max_size, int* dims, int num_dims) { + if (max_size <= 0) return; + buffer[0] = '?'; + int size = 1; + for (int i = 1; i < num_dims; ++i) { + if (max_size > size) { + int written_size = + snprintf(buffer + size, max_size - size, ",%d", dims[i]); + if (written_size < 0) return; + size += written_size; + } + } +} + +TfLiteStatus checkInputs(JNIEnv* env, tflite::Interpreter* interpreter, + const int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values, + jobjectArray sizes) { + if (input_size != interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Expected num of inputs is %d but got %d", + interpreter->inputs().size(), input_size); + return kTfLiteError; + } + if (input_size != env->GetArrayLength(data_types) || + input_size != env->GetArrayLength(nums_of_bytes) || + input_size != env->GetArrayLength(values)) { + throwException(env, kIllegalArgumentException, + "Arrays in arguments should be of the same length, but got " + "%d sizes, %d data_types, %d nums_of_bytes, and %d values", + input_size, env->GetArrayLength(data_types), + env->GetArrayLength(nums_of_bytes), + env->GetArrayLength(values)); + return kTfLiteError; + } + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + int num_dims = static_cast(env->GetArrayLength(dims)); + if (target->dims->size != num_dims) { + throwException(env, kIllegalArgumentException, + "%d-th input should have %d dimensions, but found %d " + "dimensions", + i, target->dims->size, num_dims); + return kTfLiteError; + } + jint* ptr = env->GetIntArrayElements(dims, nullptr); + for (int j = 1; j < num_dims; ++j) { + if (target->dims->data[j] != ptr[j]) { + std::unique_ptr expected_dims(new char[kBufferSize]); + std::unique_ptr obtained_dims(new char[kBufferSize]); + printDims(expected_dims.get(), kBufferSize, target->dims->data, + num_dims); + printDims(obtained_dims.get(), kBufferSize, ptr, num_dims); + throwException(env, kIllegalArgumentException, + "%d-th input dimension should be [%s], but found [%s]", + i, expected_dims.get(), obtained_dims.get()); + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + return kTfLiteError; + } + } + env->ReleaseIntArrayElements(dims, ptr, JNI_ABORT); + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus resizeInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jobjectArray sizes) { + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + jintArray dims = + static_cast(env->GetObjectArrayElement(sizes, i)); + TfLiteStatus status = interpreter->ResizeInputTensor( + input_idx, convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + return status; + } + env->DeleteLocalRef(dims); + if (env->ExceptionCheck()) return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, + int input_size, jintArray data_types, + jintArray nums_of_bytes, jobjectArray values) { + jint* data_type = env->GetIntArrayElements(data_types, nullptr); + jint* num_bytes = env->GetIntArrayElements(nums_of_bytes, nullptr); + for (int i = 0; i < input_size; ++i) { + int input_idx = interpreter->inputs()[i]; + TfLiteTensor* target = interpreter->tensor(input_idx); + jobject value = env->GetObjectArrayElement(values, i); + bool is_byte_buffer = isByteBuffer(data_type[i]); + if (is_byte_buffer) { + writeByteBuffer(env, value, &(target->data.raw), + static_cast(num_bytes[i])); + } else { + TfLiteType type = resolveDataType(data_type[i]); + if (type != target->type) { + throwException(env, kIllegalArgumentException, + "DataType (%d) of input data does not match with the " + "DataType (%d) of model inputs.", + type, target->type); + return kTfLiteError; + } + writeMultiDimensionalArray(env, value, target->type, target->dims->size, + &(target->data.raw), + static_cast(num_bytes[i])); + } + env->DeleteLocalRef(value); + if (env->ExceptionCheck()) return kTfLiteError; + } + env->ReleaseIntArrayElements(data_types, data_type, JNI_ABORT); + env->ReleaseIntArrayElements(nums_of_bytes, num_bytes, JNI_ABORT); + return kTfLiteOk; +} + +} // namespace + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get input names."); + return nullptr; + } + size_t size = interpreter->inputs().size(); + jobjectArray names = static_cast( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement(names, i, + env->NewStringUTF(interpreter->GetInputName(i))); + } + return names; +} + +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + jclass string_class = env->FindClass("java/lang/String"); + if (string_class == nullptr) { + throwException(env, kUnsupportedOperationException, + "Can not find java/lang/String class to get output names."); + return nullptr; + } + size_t size = interpreter->outputs().size(); + jobjectArray names = static_cast( + env->NewObjectArray(size, string_class, env->NewStringUTF(""))); + for (int i = 0; i < size; ++i) { + env->SetObjectArrayElement( + names, i, env->NewStringUTF(interpreter->GetOutputName(i))); + } + return names; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass clazz, + jlong handle, + jboolean state) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return; + interpreter->UseNNAPI(static_cast(state)); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size) { + BufferErrorReporter* error_reporter = + new BufferErrorReporter(env, static_cast(size)); + return reinterpret_cast(error_reporter); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* path = env->GetStringUTFChars(model_file, nullptr); + auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "Contents of %s does not encode a valid TensorFlowLite " + "model: %s", + path, error_reporter->CachedErrorMessage()); + env->ReleaseStringUTFChars(model_file, path); + return 0; + } + env->ReleaseStringUTFChars(model_file, path); + return reinterpret_cast(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass /*clazz*/, jobject model_buffer, jlong error_handle) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; + const char* buf = + static_cast(env->GetDirectBufferAddress(model_buffer)); + jlong capacity = env->GetDirectBufferCapacity(model_buffer); + auto model = tflite::FlatBufferModel::BuildFromBuffer( + buf, static_cast(capacity), error_reporter); + if (!model) { + throwException(env, kIllegalArgumentException, + "MappedByteBuffer does not encode a valid TensorFlowLite " + "model: %s", + error_reporter->CachedErrorMessage()); + return 0; + } + return reinterpret_cast(model.release()); +} + +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle) { + tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); + if (model == nullptr) return 0; + auto resolver = ::tflite::CreateOpResolver(); + std::unique_ptr interpreter; + tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + return reinterpret_cast(interpreter.release()); +} + +// Sets inputs, runs inference, and returns outputs as long handles. +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values) { + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return nullptr; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return nullptr; + const int input_size = env->GetArrayLength(sizes); + // validates inputs + TfLiteStatus status = checkInputs(env, interpreter, input_size, data_types, + nums_of_bytes, values, sizes); + if (status != kTfLiteOk) return nullptr; + // resizes inputs + status = resizeInputs(env, interpreter, input_size, sizes); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, "Can not resize the input: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // allocates memory + status = interpreter->AllocateTensors(); + if (status != kTfLiteOk) { + throwException(env, kNullPointerException, + "Can not allocate memory for the given inputs: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // sets inputs + status = setInputs(env, interpreter, input_size, data_types, nums_of_bytes, + values); + if (status != kTfLiteOk) return nullptr; + // runs inference + if (interpreter->Invoke() != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to run on the given Interpreter: %s", + error_reporter->CachedErrorMessage()); + return nullptr; + } + // returns outputs + const std::vector& results = interpreter->outputs(); + if (results.empty()) { + throwException(env, kIllegalArgumentException, + "The Interpreter does not have any outputs."); + return nullptr; + } + jlongArray outputs = env->NewLongArray(results.size()); + size_t size = results.size(); + for (int i = 0; i < size; ++i) { + TfLiteTensor* source = interpreter->tensor(results[i]); + jlong output = reinterpret_cast(source); + env->SetLongArrayRegion(outputs, i, 1, &output); + } + return outputs; +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes) { + tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle); + if (interpreter == nullptr) return nullptr; + const int idx = static_cast(input_idx); + if (input_idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Out of range: Failed to get %d-th input out of %d inputs", + input_idx, interpreter->inputs().size()); + return nullptr; + } + TfLiteTensor* target = interpreter->tensor(interpreter->inputs()[idx]); + int size = target->dims->size; + int expected_num_bytes = elementByteSize(target->type); + for (int i = 0; i < size; ++i) { + expected_num_bytes *= target->dims->data[i]; + } + if (num_bytes != expected_num_bytes) { + throwException(env, kIllegalArgumentException, + "Failed to get input dimensions. %d-th input should have" + " %d bytes, but found %d bytes.", + idx, expected_num_bytes, num_bytes); + return nullptr; + } + jintArray outputs = env->NewIntArray(size); + env->SetIntArrayRegion(outputs, 0, size, &(target->dims->data[0])); + return outputs; +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims) { + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return; + tflite::Interpreter* interpreter = + convertLongToInterpreter(env, interpreter_handle); + if (interpreter == nullptr) return; + const int idx = static_cast(input_idx); + if (idx < 0 || idx >= interpreter->inputs().size()) { + throwException(env, kIllegalArgumentException, + "Can not resize %d-th input for a model having %d inputs.", + idx, interpreter->inputs().size()); + } + TfLiteStatus status = interpreter->ResizeInputTensor( + interpreter->inputs()[idx], convertJIntArrayToVector(env, dims)); + if (status != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Failed to resize %d-th input: %s", idx, + error_reporter->CachedErrorMessage()); + } +} + +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle) { + if (interpreter_handle != 0) { + delete convertLongToInterpreter(env, interpreter_handle); + } + if (model_handle != 0) { + delete convertLongToModel(env, model_handle); + } + if (error_handle != 0) { + delete convertLongToErrorReporter(env, error_handle); + } +} diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..430886b7cc04a356d1826843acc1bbebf4189bf7 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ + +#include +#include +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" +#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +// This is to be provided at link-time by a library. +extern std::unique_ptr CreateOpResolver(); +} // namespace tflite + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)[Ljava/lang/Object; + */ +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputNames(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)[Ljava/lang/Object; + */ +JNIEXPORT jobjectArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JZ) + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env, + jclass clazz, + jlong handle, + jboolean state); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (I)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter( + JNIEnv* env, jclass clazz, jint size); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (Ljava/lang/String;J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( + JNIEnv* env, jclass clazz, jstring model_file, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (Ljava/lang/Object;J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( + JNIEnv* env, jclass clazz, jobject model_buffer, jlong error_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( + JNIEnv* env, jclass clazz, jlong model_handle); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJ[Ljava/lang/Object;[I[I[Ljava/lang/Object;)[J + */ +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_run( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jobjectArray sizes, jintArray data_types, jintArray nums_of_bytes, + jobjectArray values); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JII)[I + * + * It gets input dimensions if num_bytes matches number of bytes required by + * the input, else returns null and throws IllegalArgumentException. + */ +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputDims( + JNIEnv* env, jclass clazz, jlong handle, jint input_idx, jint num_bytes); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJI[I) + * + * It resizes dimensions of a input. + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_NativeInterpreterWrapper_resizeInput( + JNIEnv* env, jclass clazz, jlong interpreter_handle, jlong error_handle, + jint input_idx, jintArray dims); + +/* + * Class: org_tensorflow_lite_NativeInterpreterWrapper + * Method: + * Signature: (JJJ) + */ +JNIEXPORT void JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_delete( + JNIEnv* env, jclass clazz, jlong error_handle, jlong model_handle, + jlong interpreter_handle); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_NATIVEINTERPRETERWRAPPER_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..65126e78a3003f8a69c69326124d613e878c0f9d --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc @@ -0,0 +1,242 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h" +#include +#include +#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h" + +namespace { + +TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) { + if (handle == 0) { + throwException(env, kIllegalArgumentException, + "Invalid handle to TfLiteTensor."); + return nullptr; + } + return reinterpret_cast(handle); +} + +size_t writeOneDimensionalArray(JNIEnv* env, jobject object, TfLiteType type, + void* dst, size_t dst_size) { + jarray array = static_cast(object); + const int num_elements = env->GetArrayLength(array); + size_t to_copy = num_elements * elementByteSize(type); + if (to_copy > dst_size) { + throwException(env, kIllegalStateException, + "cannot write Java array of %d bytes to Tensor of %d bytes", + to_copy, dst_size); + return 0; + } + switch (type) { + case kTfLiteFloat32: { + jfloatArray a = static_cast(array); + jfloat* values = env->GetFloatArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseFloatArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteInt32: { + jintArray a = static_cast(array); + jint* values = env->GetIntArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseIntArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteInt64: { + jlongArray a = static_cast(array); + jlong* values = env->GetLongArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseLongArrayElements(a, values, JNI_ABORT); + return to_copy; + } + case kTfLiteUInt8: { + jbyteArray a = static_cast(array); + jbyte* values = env->GetByteArrayElements(a, nullptr); + memcpy(dst, values, to_copy); + env->ReleaseByteArrayElements(a, values, JNI_ABORT); + return to_copy; + } + default: { + throwException(env, kUnsupportedOperationException, + "TensorFlowLite currently supports float (32 bits), " + "int (32 bits), byte (8 bits), and long (64 bits), " + "support for other types (DataType %d in this case) will " + "be added in the future", + kTfLiteFloat32, type); + return 0; + } + } +} + +size_t readOneDimensionalArray(JNIEnv* env, TfLiteType data_type, + const void* src, size_t src_size, jarray dst) { + const int len = env->GetArrayLength(dst); + const size_t size = len * elementByteSize(data_type); + if (size > src_size) { + throwException( + env, kIllegalStateException, + "cannot fill a Java array of %d bytes with a Tensor of %d bytes", size, + src_size); + return 0; + } + switch (data_type) { + case kTfLiteFloat32: { + jfloatArray float_array = static_cast(dst); + env->SetFloatArrayRegion(float_array, 0, len, + static_cast(src)); + return size; + } + case kTfLiteInt32: { + jintArray int_array = static_cast(dst); + env->SetIntArrayRegion(int_array, 0, len, static_cast(src)); + return size; + } + case kTfLiteInt64: { + jlongArray long_array = static_cast(dst); + env->SetLongArrayRegion(long_array, 0, len, + static_cast(src)); + return size; + } + case kTfLiteUInt8: { + jbyteArray byte_array = static_cast(dst); + env->SetByteArrayRegion(byte_array, 0, len, + static_cast(src)); + return size; + } + default: { + throwException(env, kIllegalStateException, "invalid DataType(%d)", + data_type); + } + } + return 0; +} + +size_t readMultiDimensionalArray(JNIEnv* env, TfLiteType data_type, char* src, + size_t src_size, int dims_left, jarray dst) { + if (dims_left == 1) { + return readOneDimensionalArray(env, data_type, src, src_size, dst); + } else { + jobjectArray ndarray = static_cast(dst); + int len = env->GetArrayLength(ndarray); + size_t size = 0; + for (int i = 0; i < len; ++i) { + jarray row = static_cast(env->GetObjectArrayElement(ndarray, i)); + size += readMultiDimensionalArray(env, data_type, src + size, + src_size - size, dims_left - 1, row); + env->DeleteLocalRef(row); + if (env->ExceptionCheck()) return size; + } + return size; + } +} + +} // namespace + +size_t elementByteSize(TfLiteType data_type) { + // The code in this file makes the assumption that the + // TensorFlow TF_DataTypes and the Java primitive types + // have the same byte sizes. Validate that: + switch (data_type) { + case kTfLiteFloat32: + static_assert(sizeof(jfloat) == 4, + "Java float not compatible with kTfLiteFloat"); + return 4; + case kTfLiteInt32: + static_assert(sizeof(jint) == 4, + "Java int not compatible with kTfLiteInt"); + return 4; + case kTfLiteUInt8: + static_assert(sizeof(jbyte) == 1, + "Java byte not compatible with kTfLiteUInt8"); + return 1; + case kTfLiteInt64: + static_assert(sizeof(jlong) == 8, + "Java long not compatible with kTfLiteInt64"); + return 8; + default: + return 0; + } +} + +size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size) { + char* buf = static_cast(env->GetDirectBufferAddress(object)); + if (!buf) { + throwException(env, kIllegalArgumentException, + "Input ByteBuffer is not a direct buffer"); + return 0; + } + *dst = buf; + return dst_size; +} + +size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, + int dims_left, char** dst, int dst_size) { + if (dims_left <= 1) { + return writeOneDimensionalArray(env, src, type, *dst, dst_size); + } else { + jobjectArray ndarray = static_cast(src); + int len = env->GetArrayLength(ndarray); + size_t sz = 0; + for (int i = 0; i < len; ++i) { + jobject row = env->GetObjectArrayElement(ndarray, i); + char* next_dst = *dst + sz; + sz += writeMultiDimensionalArray(env, row, type, dims_left - 1, &next_dst, + dst_size - sz); + env->DeleteLocalRef(row); + if (env->ExceptionCheck()) return sz; + } + return sz; + } +} + +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject value) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return; + int num_dims = tensor->dims->size; + if (num_dims == 0) { + throwException(env, kIllegalArgumentException, + "copyTo() is not meant for scalar Tensors."); + return; + } + readMultiDimensionalArray(env, tensor->type, tensor->data.raw, tensor->bytes, + num_dims, static_cast(value)); +} + +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, + jclass clazz, + jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return 0; + return static_cast(tensor->type); +} + +JNIEXPORT jintArray JNICALL +Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) { + TfLiteTensor* tensor = convertLongToTensor(env, handle); + if (tensor == nullptr) return nullptr; + int num_dims = tensor->dims->size; + jintArray result = env->NewIntArray(num_dims); + jint* dims = env->GetIntArrayElements(result, nullptr); + for (int i = 0; i < num_dims; ++i) { + dims[i] = static_cast(tensor->dims->data[i]); + } + env->ReleaseIntArrayElements(result, dims, 0); + return result; +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..3a4910dcc3a719fbb9f365dae693423de768349c --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ + +#include +#include "tensorflow/contrib/lite/context.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (J)[I + */ +JNIEXPORT jintArray JNICALL Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, + jclass clazz, + jlong handle); + +/* + * Class: org_tensorflow_lite_TfLiteTensor + * Method: + * Signature: (JLjava/lang/Object;) + */ +JNIEXPORT void JNICALL +Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env, + jclass clazz, + jlong handle, + jobject value); + +/* + * Finds the size of each data type. + */ +size_t elementByteSize(TfLiteType data_type); + +/* + * Writes data of a ByteBuffer into dest. + */ +size_t writeByteBuffer(JNIEnv* env, jobject object, char** dst, int dst_size); + +/* + * Writes a multi-dimensional array into dest. + */ +size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type, + int dims_left, char** dst, int dst_size); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSOR_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc new file mode 100644 index 0000000000000000000000000000000000000000..2e7f2f56921b871a6ace2b6cb984fcd185a4d2ab --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.cc @@ -0,0 +1,26 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h" +#include "tensorflow/contrib/lite/version.h" + +JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv* env, jclass /*clazz*/) { + char buf[64]; + snprintf(buf, sizeof(buf), "%d", TFLITE_SCHEMA_VERSION); + return env->NewStringUTF(buf); +} diff --git a/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h new file mode 100644 index 0000000000000000000000000000000000000000..65f8341149287f151f7e51fe04d9525bf119164e --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/tensorflow_lite_jni.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ +#define TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +/* + * Class: org_tensorflow_lite_TensorFlowLite + * Method: version + * Signature: ()Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL +Java_org_tensorflow_lite_TensorFlowLite_version(JNIEnv*, jclass); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_JAVA_TENSORFLOW_LITE_JNI_H_ diff --git a/tensorflow/contrib/lite/java/src/main/native/version_script.lds b/tensorflow/contrib/lite/java/src/main/native/version_script.lds new file mode 100644 index 0000000000000000000000000000000000000000..38c93dda730550070f28b59297c5191a9615ed7b --- /dev/null +++ b/tensorflow/contrib/lite/java/src/main/native/version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + # Export JNI symbols. + global: + Java_*; + JNI_OnLoad; + JNI_OnUnload; + + # Hide everything else. + local: + *; +}; diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java new file mode 100644 index 0000000000000000000000000000000000000000..cebc9442008e10e7674cf7b1dc58e633fef4ba39 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/DataTypeTest.java @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.DataType}. */ +@RunWith(JUnit4.class) +public final class DataTypeTest { + + @Test + public void testElemByteSize() { + assertThat(DataType.FLOAT32.elemByteSize()).isEqualTo(4); + assertThat(DataType.INT32.elemByteSize()).isEqualTo(4); + assertThat(DataType.UINT8.elemByteSize()).isEqualTo(1); + assertThat(DataType.INT64.elemByteSize()).isEqualTo(8); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java new file mode 100644 index 0000000000000000000000000000000000000000..424b3de6c97672e310c54230a7ac1204f46d9ac8 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java @@ -0,0 +1,221 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Interpreter}. */ +@RunWith(JUnit4.class) +public final class InterpreterTest { + + private static final File MODEL_FILE = + new File("tensorflow/contrib/lite/java/src/testdata/add.bin"); + + private static final File MOBILENET_MODEL_FILE = + new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin"); + + @Test + public void testInterpreter() throws Exception { + Interpreter interpreter = new Interpreter(MODEL_FILE); + assertThat(interpreter).isNotNull(); + interpreter.close(); + } + + @Test + public void testRunWithMappedByteBufferModel() throws Exception { + Path path = MODEL_FILE.toPath(); + FileChannel fileChannel = + (FileChannel) Files.newByteChannel(path, EnumSet.of(StandardOpenOption.READ)); + MappedByteBuffer mappedByteBuffer = + fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, fileChannel.size()); + Interpreter interpreter = new Interpreter(mappedByteBuffer); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + fileChannel.close(); + } + + @Test + public void testRun() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + Float[] oneD = {1.23f, 6.54f, 7.81f}; + Float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + Float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + Float[][][][] fourD = {threeD, threeD}; + Float[][][][] parsedOutputs = new Float[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [[[[Ljava.lang.Float;"); + } + interpreter.close(); + } + + @Test + public void testRunWithBoxedInputs() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + interpreter.run(fourD, parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testRunForMultipleInputsOutputs() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + Map outputs = new HashMap<>(); + outputs.put(0, parsedOutputs); + interpreter.runForMultipleInputsOutputs(inputs, outputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + interpreter.close(); + } + + @Test + public void testMobilenetRun() { + // Create a gray image. + float[][][][] img = new float[1][224][224][3]; + for (int i = 0; i < 224; ++i) { + for (int j = 0; j < 224; ++j) { + img[0][i][j][0] = 0.5f; + img[0][i][j][1] = 0.5f; + img[0][i][j][2] = 0.5f; + } + } + + // Allocate memory to receive the output values. + float[][] labels = new float[1][1001]; + + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + interpreter.run(img, labels); + interpreter.close(); + + assertThat(labels[0]) + .usingExactEquality() + .containsNoneOf(new float[] {Float.NaN, Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY}); + } + + @Test + public void testRunWithWrongInputType() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + int[] oneD = {4, 3, 9}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + float[][][][] parsedOutputs = new float[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "DataType (2) of input data does not match with the DataType (1) of model inputs."); + } + interpreter.close(); + } + + @Test + public void testRunWithWrongOutputType() { + Interpreter interpreter = new Interpreter(MODEL_FILE); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + int[][][][] parsedOutputs = new int[2][8][8][3]; + try { + interpreter.run(fourD, parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Cannot convert an TensorFlowLite tensor with type " + + "FLOAT32 to a Java object of type [[[[I (which is compatible with the" + + " TensorFlowLite type INT32)"); + } + interpreter.close(); + } + + @Test + public void testGetInputIndex() { + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + try { + interpreter.getInputIndex("WrongInputName"); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "WrongInputName is not a valid name for any input. The indexes of the inputs" + + " are {input=0}"); + } + int index = interpreter.getInputIndex("input"); + assertThat(index).isEqualTo(0); + } + + @Test + public void testGetOutputIndex() { + Interpreter interpreter = new Interpreter(MOBILENET_MODEL_FILE); + try { + interpreter.getOutputIndex("WrongOutputName"); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "WrongOutputName is not a valid name for any output. The indexes of the outputs" + + " are {MobilenetV1/Predictions/Softmax=0}"); + } + int index = interpreter.getOutputIndex("MobilenetV1/Predictions/Softmax"); + assertThat(index).isEqualTo(0); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java new file mode 100644 index 0000000000000000000000000000000000000000..9a6894f49c0b7278511717d2671648c6d1763e00 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -0,0 +1,406 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.NativeInterpreterWrapper}. */ +@RunWith(JUnit4.class) +public final class NativeInterpreterWrapperTest { + + private static final String FLOAT_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/add.bin"; + + private static final String INT_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/int32.bin"; + + private static final String LONG_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/int64.bin"; + + private static final String BYTE_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/uint8.bin"; + + private static final String INVALID_MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; + + @Test + public void testConstructor() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + assertThat(wrapper).isNotNull(); + wrapper.close(); + } + + @Test + public void testConstructorWithInvalidModel() { + try { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Model provided has model identifier ' is ', should be 'TFL3'"); + } + } + + @Test + public void testRunWithFloat() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, -6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[2][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + + @Test + public void testRunWithInt() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INT_MODEL_PATH); + int[] oneD = {3, 7, -4}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + int[][][][] parsedOutputs = new int[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + int[] outputOneD = parsedOutputs[0][0][0]; + int[] expected = {3, 7, -4, 3, 7, -4, 3, 7, -4, 3, 7, -4}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithLong() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(LONG_MODEL_PATH); + long[] oneD = {-892834092L, 923423L, 2123918239018L}; + long[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + long[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + long[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + long[][][][] parsedOutputs = new long[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + long[] outputOneD = parsedOutputs[0][0][0]; + long[] expected = {-892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L, + -892834092L, 923423L, 2123918239018L, -892834092L, 923423L, 2123918239018L}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByte() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + byte[] oneD = {(byte) 0xe0, 0x4f, (byte) 0xd0}; + byte[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + byte[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + byte[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + int[] inputDims = {2, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + byte[][][][] parsedOutputs = new byte[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + byte[] outputOneD = parsedOutputs[0][0][0]; + byte[] expected = {(byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0}; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingBytes() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 8 * 8 * 3); + bbuf.order(ByteOrder.nativeOrder()); + bbuf.rewind(); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + bbuf.put((byte) 0xe0); + bbuf.put((byte) 0x4f); + bbuf.put((byte) 0xd0); + } + } + } + Object[] inputs = {bbuf}; + int[] inputDims = {2, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + byte[][][][] parsedOutputs = new byte[2][4][4][12]; + outputs[0].copyTo(parsedOutputs); + byte[] outputOneD = parsedOutputs[0][0][0]; + byte[] expected = { + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0, + (byte) 0xe0, 0x4f, (byte) 0xd0, (byte) 0xe0, 0x4f, (byte) 0xd0 + }; + assertThat(outputOneD).isEqualTo(expected); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingFloats() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(4 * 8 * 8 * 3 * 4); + bbuf.order(ByteOrder.nativeOrder()); + bbuf.rewind(); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 8; ++j) { + for (int k = 0; k < 8; ++k) { + bbuf.putFloat(1.23f); + bbuf.putFloat(-6.54f); + bbuf.putFloat(7.81f); + } + } + } + Object[] inputs = {bbuf}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 0-th input should have 768 bytes, but found 3072 bytes"); + } + int[] inputDims = {4, 8, 8, 3}; + wrapper.resizeInput(0, inputDims); + Tensor[] outputs = wrapper.run(inputs); + assertThat(outputs.length).isEqualTo(1); + float[][][][] parsedOutputs = new float[4][8][8][3]; + outputs[0].copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, -19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + wrapper.close(); + } + + @Test + public void testRunWithByteBufferHavingWrongSize() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(BYTE_MODEL_PATH); + ByteBuffer bbuf = ByteBuffer.allocateDirect(2 * 7 * 8 * 3); + bbuf.order(ByteOrder.nativeOrder()); + Object[] inputs = {bbuf}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Failed to get input dimensions. 0-th input should have 192 bytes, but found 336 bytes."); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputType() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + int[] oneD = {4, 3, 9}; + int[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + int[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + int[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "DataType (2) of input data does not match with the DataType (1) of model inputs."); + } + wrapper.close(); + } + + @Test + public void testRunAfterClose() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + wrapper.close(); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Invalid handle to Interpreter."); + } + } + + @Test + public void testRunWithEmptyInputs() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + try { + Object[] inputs = {}; + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("Invalid inputs. Inputs should not be null or empty."); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputSize() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD, fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Expected num of inputs is 1 but got 2"); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputNumOfDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + Object[] inputs = {threeD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input should have 4 dimensions, but found 3 dimensions"); + } + wrapper.close(); + } + + @Test + public void testRunWithWrongInputDims() { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + try { + wrapper.run(inputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains("0-th input dimension should be [?,8,8,3], but found [?,8,7,3]"); + } + wrapper.close(); + } + + @Test + public void testNumElements() { + int[] shape = {2, 3, 4}; + int num = NativeInterpreterWrapper.numElements(shape); + assertThat(num).isEqualTo(24); + shape = null; + num = NativeInterpreterWrapper.numElements(shape); + assertThat(num).isEqualTo(0); + } + + @Test + public void testIsNonEmtpyArray() { + assertThat(NativeInterpreterWrapper.isNonEmptyArray(null)).isFalse(); + assertThat(NativeInterpreterWrapper.isNonEmptyArray(3.2)).isFalse(); + int[] emptyArray = {}; + assertThat(NativeInterpreterWrapper.isNonEmptyArray(emptyArray)).isFalse(); + int[] validArray = {9, 5, 2, 1}; + assertThat(NativeInterpreterWrapper.isNonEmptyArray(validArray)).isTrue(); + } + + @Test + public void testDataTypeOf() { + float[] testEmtpyArray = {}; + DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[] testFloatArray = {0.783f, 0.251f}; + dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; + dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + try { + double[] testDoubleArray = {0.783, 0.251}; + NativeInterpreterWrapper.dataTypeOf(testDoubleArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); + } + try { + Float[] testBoxedArray = {0.783f, 0.251f}; + NativeInterpreterWrapper.dataTypeOf(testBoxedArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); + } + } + + @Test + public void testNumDimensions() { + int scalar = 1; + assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0); + int[][] array = {{2, 4}, {1, 9}}; + assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2); + try { + int[] emptyArray = {}; + NativeInterpreterWrapper.numDimensions(emptyArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("array lengths cannot be 0."); + } + } + + @Test + public void testFillShape() { + int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; + int num = NativeInterpreterWrapper.numDimensions(array); + int[] shape = new int[num]; + NativeInterpreterWrapper.fillShape(array, 0, shape); + assertThat(num).isEqualTo(3); + assertThat(shape[0]).isEqualTo(2); + assertThat(shape[1]).isEqualTo(3); + assertThat(shape[2]).isEqualTo(1); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java new file mode 100644 index 0000000000000000000000000000000000000000..665c937cb60ad957c0030c01eb57899754c80bf8 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorFlowLiteTest.java @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.TensorFlowLite}. */ +@RunWith(JUnit4.class) +public final class TensorFlowLiteTest { + + @Test + public void testVersion() { + assertThat(TensorFlowLite.version()).isEqualTo("3"); + } +} diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java new file mode 100644 index 0000000000000000000000000000000000000000..94b6632bb8dd7117bf4074da1939bd23ce732efd --- /dev/null +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link org.tensorflow.lite.Tensor}. */ +@RunWith(JUnit4.class) +public final class TensorTest { + + private static final String MODEL_PATH = + "tensorflow/contrib/lite/java/src/testdata/add.bin"; + + private NativeInterpreterWrapper wrapper; + private long nativeHandle; + + @Before + public void setUp() { + wrapper = new NativeInterpreterWrapper(MODEL_PATH); + float[] oneD = {1.23f, 6.54f, 7.81f}; + float[][] twoD = {oneD, oneD, oneD, oneD, oneD, oneD, oneD, oneD}; + float[][][] threeD = {twoD, twoD, twoD, twoD, twoD, twoD, twoD, twoD}; + float[][][][] fourD = {threeD, threeD}; + Object[] inputs = {fourD}; + Tensor[] outputs = wrapper.run(inputs); + nativeHandle = outputs[0].nativeHandle; + } + + @After + public void tearDown() { + wrapper.close(); + } + + @Test + public void testFromHandle() throws Exception { + Tensor tensor = Tensor.fromHandle(nativeHandle); + assertThat(tensor).isNotNull(); + int[] expectedShape = {2, 8, 8, 3}; + assertThat(tensor.shapeCopy).isEqualTo(expectedShape); + assertThat(tensor.dtype).isEqualTo(DataType.FLOAT32); + } + + @Test + public void testCopyTo() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + float[][][][] parsedOutputs = new float[2][8][8][3]; + tensor.copyTo(parsedOutputs); + float[] outputOneD = parsedOutputs[0][0][0]; + float[] expected = {3.69f, 19.62f, 23.43f}; + assertThat(outputOneD).usingTolerance(0.1f).containsExactly(expected).inOrder(); + } + + @Test + public void testCopyToWrongType() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + int[][][][] parsedOutputs = new int[2][8][8][3]; + try { + tensor.copyTo(parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Cannot convert an TensorFlowLite tensor with type " + + "FLOAT32 to a Java object of type [[[[I (which is compatible with the TensorFlowLite " + + "type INT32)"); + } + } + + @Test + public void testCopyToWrongShape() { + Tensor tensor = Tensor.fromHandle(nativeHandle); + float[][][][] parsedOutputs = new float[1][8][8][3]; + try { + tensor.copyTo(parsedOutputs); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e) + .hasMessageThat() + .contains( + "Shape of output target [1, 8, 8, 3] does not match " + + "with the shape of the Tensor [2, 8, 8, 3]."); + } + } +} diff --git a/tensorflow/contrib/lite/java/src/testdata/add.bin b/tensorflow/contrib/lite/java/src/testdata/add.bin new file mode 100644 index 0000000000000000000000000000000000000000..aef0fe3d82c9d92dc444076d3b46e05af1923f46 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/add.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/float32.bin b/tensorflow/contrib/lite/java/src/testdata/float32.bin new file mode 100644 index 0000000000000000000000000000000000000000..30b1264ca152740e1607651ce6cbc2a548319bc3 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/float32.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/int32.bin b/tensorflow/contrib/lite/java/src/testdata/int32.bin new file mode 100644 index 0000000000000000000000000000000000000000..f6f3cf607a249e096921b12d848c4055a37d1168 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/int32.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/int64.bin b/tensorflow/contrib/lite/java/src/testdata/int64.bin new file mode 100644 index 0000000000000000000000000000000000000000..c12aa41ca7be49b30db291a25156bd20cbab21a9 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/int64.bin differ diff --git a/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin b/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin new file mode 100644 index 0000000000000000000000000000000000000000..8156ac741cbc0aa32e6d867ad09b5e6be8451868 --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testdata/invalid_model.bin @@ -0,0 +1 @@ +This is an invalid model. \ No newline at end of file diff --git a/tensorflow/contrib/lite/java/src/testdata/uint8.bin b/tensorflow/contrib/lite/java/src/testdata/uint8.bin new file mode 100644 index 0000000000000000000000000000000000000000..f06c5cf58462ce56b012d163fb208329874f83ad Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/uint8.bin differ diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..2b4f37bc6cfe1dbc0c178a56b892f545e8ad4f3b --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/BUILD @@ -0,0 +1,30 @@ +# Description: +# Internal helper function to test TF Lite API. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +android_library( + name = "testhelper", + srcs = glob( + [ + "*.java", + ], + ), + deps = [ + "//tensorflow/contrib/lite/java:tensorflowlite_java", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java new file mode 100644 index 0000000000000000000000000000000000000000..8660cabf709e6531a5667a16e5cf43a93c7135bd --- /dev/null +++ b/tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite/TestHelper.java @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.lite; + +/** A helper class for internal tests. */ +public class TestHelper { + + /** + * Turns on/off NNAPI of an {@code Interpreter}. + * + * @param interpreter an instance of {@code Interpreter}. If it is not initialized, an {@code + * IllegalArgumentException} will be thrown. + * @param useNNAPI a boolean value indicating to turn on or off NNAPI. + */ + public static void setUseNNAPI(Interpreter interpreter, boolean useNNAPI) { + if (interpreter != null && interpreter.wrapper != null) { + interpreter.wrapper.setUseNNAPI(useNNAPI); + } else { + throw new IllegalArgumentException("Interpreter has not initialized; Failed to setUseNNAPI."); + } + } +} diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ad76e906064b30801b4c2484cfe180589241afe1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -0,0 +1,409 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +tf_cc_test( + name = "optional_tensor_test", + size = "small", + srcs = ["optional_tensor_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:lib", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "gemm_support", + srcs = [ + "gemm_support.cc", + ], + hdrs = [ + "gemm_support.h", + ], + copts = tflite_copts(), + deps = [ + ":op_macros", + "//tensorflow/contrib/lite:context", + "@gemmlowp//:gemmlowp", + ], +) + +cc_library( + name = "activation_functor", + hdrs = [ + "activation_functor.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + ], +) + +cc_library( + name = "op_macros", + hdrs = [ + "op_macros.h", + ], +) + +cc_library( + name = "builtin_ops", + srcs = [ + "activations.cc", + "add.cc", + "basic_rnn.cc", + "concatenation.cc", + "conv.cc", + "depthwise_conv.cc", + "embedding_lookup.cc", + "embedding_lookup_sparse.cc", + "fully_connected.cc", + "hashtable_lookup.cc", + "kernel_util.cc", + "l2norm.cc", + "local_response_norm.cc", + "lsh_projection.cc", + "lstm.cc", + "mul.cc", + "pooling.cc", + "register.cc", + "reshape.cc", + "resize_bilinear.cc", + "skip_gram.cc", + "space_to_depth.cc", + "svdf.cc", + ], + hdrs = [ + "kernel_util.h", + "padding.h", + "register.h", + ], + # Suppress warnings that are introduced by Eigen Tensor. + copts = tflite_copts() + [ + "-Wno-error=reorder", + ] + select({ + "//tensorflow:ios": ["-Wno-error=invalid-partial-specialization"], + "//conditions:default": [ + ], + }), + deps = [ + ":activation_functor", + ":op_macros", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:optimized", + "//tensorflow/contrib/lite/kernels/internal:optimized_base", + "//tensorflow/contrib/lite/kernels/internal:quantization_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "//tensorflow/contrib/lite/kernels/internal:round", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "@farmhash_archive//:farmhash", + ], +) + +tf_cc_test( + name = "activations_test", + size = "small", + srcs = ["activations_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "add_test", + size = "small", + srcs = ["add_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "concatenation_test", + size = "small", + srcs = ["concatenation_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "conv_test", + size = "small", + srcs = ["conv_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "depthwise_conv_test", + size = "small", + srcs = ["depthwise_conv_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "basic_rnn_test", + size = "small", + srcs = ["basic_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "l2norm_test", + size = "small", + srcs = ["l2norm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "mul_test", + size = "small", + srcs = ["mul_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "reshape_test", + size = "small", + srcs = ["reshape_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "resize_bilinear_test", + size = "small", + srcs = ["resize_bilinear_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "svdf_test", + size = "small", + srcs = ["svdf_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "embedding_lookup_test", + size = "small", + srcs = ["embedding_lookup_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "embedding_lookup_sparse_test", + size = "small", + srcs = ["embedding_lookup_sparse_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "fully_connected_test", + size = "small", + srcs = ["fully_connected_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "local_response_norm_test", + size = "small", + srcs = ["local_response_norm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "pooling_test", + size = "small", + srcs = ["pooling_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "softmax_test", + size = "small", + srcs = ["softmax_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "lsh_projection_test", + size = "small", + srcs = ["lsh_projection_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "hashtable_lookup_test", + size = "small", + srcs = ["hashtable_lookup_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "lstm_test", + size = "small", + srcs = ["lstm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "skip_gram_test", + size = "small", + srcs = ["skip_gram_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "space_to_depth_test", + size = "small", + srcs = ["space_to_depth_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..cfb3369e991a474315424423fe655ba214edabbc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activation_functor.h @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ + +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { + +// Dynamic (non-fused) activation functor. perhaps it is worth having +// template instantiation? +// TODO(aselle): Make this more efficient by pulling the switch to conv_eval +// using template inlining. +class ActivationFunctor { + public: + explicit ActivationFunctor(TfLiteFusedActivation act) : act_(act) {} + + float operator()(float a) const { + switch (act_) { + case kTfLiteActNone: + return a; + case kTfLiteActRelu: + return a < 0.f ? 0.f : a; + case kTfLiteActRelu6: + return std::max(0.f, std::min(a, 6.f)); + case kTfLiteActTanh: + return std::tanh(a); + case kTfLiteActSigmoid: + return 1.0f / (1.0f + std::exp(-a)); + default: + // TODO(aselle): More informative fatal error! + exit(1); + } + } + + private: + TfLiteFusedActivation act_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ab60a33e5e2ff61bae5f4c6db85ab9c47a391bc --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -0,0 +1,389 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace activations { + +struct OpData { + int32_t input_multiplier = 0; + int input_left_shift = 0; + int32_t input_range_radius = 0; + int diff_min = 0; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE(context, output->params.scale == 1. / 256); + + static constexpr int kInputIntegerBits = 4; + + const double input_real_multiplier = + input->params.scale * + static_cast(1 << (31 - kInputIntegerBits)); + + QuantizeMultiplierGreaterThanOne(input_real_multiplier, + &data->input_multiplier, + &data->input_left_shift); + data->input_range_radius = + CalculateInputRadius(kInputIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TF_LITE_ENSURE(context, + NumDimensions(input) == 2 || NumDimensions(input) == 4); + + if (input->type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); + TF_LITE_ENSURE(context, output->params.scale == 1. / 256); + + static const int kScaledDiffIntegerBits = 5; + + tflite::PreprocessSoftmaxScaling( + params->beta, input->params.scale, kScaledDiffIntegerBits, + &data->input_multiplier, &data->input_left_shift); + data->diff_min = -1.0 * tflite::CalculateInputRadius( + kScaledDiffIntegerBits, data->input_left_shift); + } + + return context->ResizeTensor(context, output, + TfLiteIntArrayCopy(input->dims)); +} + +TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::max(0.f, *in); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) { + *out = std::min(std::max(-1.f, *in), 1.f); + } + return kTfLiteOk; + } break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = std::tanh(*in); + return kTfLiteOk; + } + break; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + +// Sigmoid is also know as "Logistic". +TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: { + size_t elements = input->bytes / sizeof(float); + float* in = input->data.f; + float* in_end = in + elements; + float* out = output->data.f; + for (; in < in_end; in++, out++) *out = 1.f / (1.f + std::exp(-*in)); + break; + } + case kTfLiteUInt8: { + optimized_ops::Logistic( + GetTensorData(input), GetTensorDims(input), + input->params.zero_point, data->input_range_radius, + data->input_multiplier, data->input_left_shift, + GetTensorData(output), GetTensorDims(output)); + break; + } + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Takes a 2D tensor and perform softmax along the second dimension. +void Softmax2DFloat(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + float* in = input->data.f; + float* out = output->data.f; + TF_LITE_ASSERT(input_size > 0); + + // For each batch + for (int b = 0; b < batch_size; b++) { + // Find the max coeff. + float max_coeff = in[0]; + for (int i = 1; i < input_size; i++) { + if (in[i] > max_coeff) max_coeff = in[i]; + } + + // Compute the normalized sum of exps. + float exp_sum = 0.0; + for (int i = 0; i < input_size; i++) { + out[i] = std::exp((in[i] - max_coeff) * params->beta); + exp_sum += out[i]; + } + + // Divide by the sum of exps. + float reciprocal_sum_exp = 1.f / exp_sum; + for (int i = 0; i < input_size; i++) { + out[i] *= reciprocal_sum_exp; + } + + // Advance in and out pointers for the next batch. + in += input_size; + out += input_size; + } +} + +void Softmax2DQuantized(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + // TODO(ahentz): this is arguably a dirty trick. Since the implementation + // always traverses the last dimension of a 4D tensor, we will pretend our 2D + // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X, + // 1, 1, Y) shape. + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + optimized_ops::Softmax(GetTensorData(input), + GetTensorDims({batch_size, 1, 1, input_size}), + data->input_multiplier, data->input_left_shift, + data->diff_min, GetTensorData(output), + GetTensorDims({batch_size, 1, 1, input_size})); +} + +// Takes a 4D tensor and perform softmax along the forth dimension. +void Softmax4DFloat(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params) { + optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + params->beta, GetTensorData(output), + GetTensorDims(output)); +} + +void Softmax4DQuantized(TfLiteTensor* input, TfLiteTensor* output, + TfLiteSoftmaxParams* params, OpData* data) { + optimized_ops::Softmax(GetTensorData(input), GetTensorDims(input), + data->input_multiplier, data->input_left_shift, + data->diff_min, GetTensorData(output), + GetTensorDims(output)); +} + +TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + + // TODO(ahentz): consider an implementation that works for many (all?) + // dimensions. + switch (input->type) { + case kTfLiteFloat32: { + if (NumDimensions(input) == 2) { + Softmax2DFloat(input, output, params); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DFloat(input, output, params); + return kTfLiteOk; + } + context->ReportError(context, + "Only 2D and 4D tensors supported currently."); + return kTfLiteError; + } + case kTfLiteUInt8: { + if (NumDimensions(input) == 2) { + Softmax2DQuantized(input, output, params, data); + return kTfLiteOk; + } + if (NumDimensions(input) == 4) { + Softmax4DQuantized(input, output, params, data); + return kTfLiteOk; + } + context->ReportError(context, + "Only 2D and 4D tensors supported currently."); + return kTfLiteError; + } + default: + context->ReportError(context, + "Only float32 and uint8_t supported currently."); + return kTfLiteError; + } +} + +} // namespace activations + +TfLiteRegistration* Register_RELU() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::ReluEval}; + return &r; +} + +TfLiteRegistration* Register_RELU1() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::Relu1Eval}; + return &r; +} + +TfLiteRegistration* Register_RELU6() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::Relu6Eval}; + return &r; +} + +TfLiteRegistration* Register_TANH() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + activations::GenericPrepare, + activations::TanhEval}; + return &r; +} + +TfLiteRegistration* Register_LOGISTIC() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::SigmoidPrepare, + activations::SigmoidEval}; + return &r; +} + +TfLiteRegistration* Register_SOFTMAX() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::SoftmaxPrepare, + activations::SoftmaxEval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..33ca56e745c043efd12b851af14f273fb273d577 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -0,0 +1,323 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseActivationsOpModel : public SingleOpModel { + public: + // Most activations don't take any options, so this constructor works for + // them. + BaseActivationsOpModel(BuiltinOperator type, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(type, BuiltinOptions_NONE, 0); + BuildInterpreter({GetShape(input_)}); + } + + // A dedicated constructor for SOFTMAX, which does some options. + BaseActivationsOpModel(float softmax_beta, TensorData input) { + input_ = AddInput(input); + if (input.type == TensorType_UINT8) { + output_ = AddOutput({input.type, {}, 0, 0, 1. / 256}); + } else { + output_ = AddOutput({input.type, {}}); + } + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, softmax_beta).Union()); + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +// TODO(ahentz): I don't quite understand the tradeoffs in the quantized +// implementation of sigmoid and software, but a tolerance of twice the output +// scale seems reasonable. We might want to change this if we have a better +// theoretical bound. +const float kQuantizedTolerance = 2 * (1. / 256); + +class QuantizedActivationsOpModel : public BaseActivationsOpModel { + public: + using BaseActivationsOpModel::BaseActivationsOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatActivationsOpTest, Relu) { + FloatActivationsOpModel m(BuiltinOperator_RELU, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 10, 1, // + })); +} + +TEST(FloatActivationsOpTest, Relu1) { + FloatActivationsOpModel m(BuiltinOperator_RELU1, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -2.0, 1.1, -0.1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0.0, -0.6, 0.2, -0.4, // + 0.3, -1.0, 1.0, -0.1, // + })); +} + +TEST(FloatActivationsOpTest, Relu6) { + FloatActivationsOpModel m(BuiltinOperator_RELU6, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 0, 0, 2, 4, // + 3, 0, 6, 1, // + })); +} + +TEST(FloatActivationsOpTest, Tanh) { + FloatActivationsOpModel m(BuiltinOperator_TANH, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0, -0.9999877, 0.9640275, 0.999329, // + 0.99505475, -0.9640275, 1, 0.7615941, // + }))); +} + +TEST(FloatActivationsOpTest, Sigmoid) { + FloatActivationsOpModel m(BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }))); +} + +TEST(QuantizedActivationsOpTest, Sigmoid) { + QuantizedActivationsOpModel m( + BuiltinOperator_LOGISTIC, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 0.5, 0.002473, 0.880797, 0.982014, // + 0.952574, 0.119203, 0.999955, 0.731059, // + }, + kQuantizedTolerance))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({128, 1, 227, 251, 244, 32, 255, 188})); +} + +TEST(FloatActivationsOpTest, Softmax4D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {1, 2, 1, 4}}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 1, 1, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax4D) { + QuantizedActivationsOpModel m( + 0.1, + /*input=*/{TensorType_UINT8, {1, 2, 1, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // depth = 0 + 3, -2, 10, 1, // depth = 1 + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2( + 0.1, + /*input=*/{TensorType_UINT8, {4, 1, 1, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +TEST(FloatActivationsOpTest, Softmax2D) { + FloatActivationsOpModel m(0.1, + /*input=*/{TensorType_FLOAT32, {2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(0.1, + /*input=*/{TensorType_FLOAT32, {4, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }))); +} + +TEST(QuantizedActivationsOpTest, Softmax2D) { + QuantizedActivationsOpModel m(0.1, + /*input=*/{TensorType_UINT8, {2, 4}, -10, 10}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + .23463, .12877, .28658, .35003, // + .22528, .13664, .45365, .18443, // + }, + kQuantizedTolerance))); + + // Same input, but a different shape. + QuantizedActivationsOpModel m2(0.1, + /*input=*/{TensorType_UINT8, {4, 2}, -10, 10}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 0.645656, 0.354344, // + 0.450166, 0.549834, // + 0.622459, 0.377541, // + 0.710949, 0.28905, // + }, + kQuantizedTolerance))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e10a249abac3ba19cf107e055aa71d1eee00122 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -0,0 +1,184 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace add { + +// This file has three implementation of Add. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_ADD(type) \ + type::Add(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops); + } else { + TF_LITE_ADD(optimized_ops); + } +#undef TF_LITE_ADD +} + +template +void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteAddParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + auto input1_offset = -input1->params.zero_point; + auto input2_offset = -input2->params.zero_point; + auto output_offset = output->params.zero_point; + const int left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / ((1 << left_shift) * output->params.scale); + + int32 input1_multiplier; + int input1_shift; + QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, + &input1_shift); + int32 input2_multiplier; + int input2_shift; + QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, + &input2_shift); + int32 output_multiplier; + int output_shift; + QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, + &output_shift); + + int32 output_activation_min, output_activation_max; + CalculateActivationRangeUint8(params->activation, output, + &output_activation_min, &output_activation_max); + +#define TF_LITE_ADD(type) \ + type::BroadcastAdd( \ + left_shift, GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, input1_multiplier, input1_shift, \ + GetTensorData(input2), GetTensorDims(input2), input2_offset, \ + input2_multiplier, input2_shift, output_offset, output_multiplier, \ + output_shift, output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + + if (kernel_type == kReference) { + TF_LITE_ADD(reference_ops); + } else { + TF_LITE_ADD(optimized_ops); + } +#undef TF_LITE_ADD +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalAddFloat(context, node, params, input1, input2, output); + } else if (output->type == kTfLiteUInt8) { + EvalAddQuantized(context, node, params, input1, input2, + output); + } else { + context->ReportError(context, + "Inputs and outputs not all float|unit8 types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace add + +TfLiteRegistration* Register_ADD_REF() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + add::Eval}; + return &r; +} + +TfLiteRegistration* Register_ADD() { +#ifdef USE_NEON + return Register_ADD_NEON_OPT(); +#else + return Register_ADD_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ddf45bb576755d57d50c9e6e01bf50f15612c56d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -0,0 +1,170 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseAddOpModel : public SingleOpModel { + public: + BaseAddOpModel(const TensorData& input, const TensorData& output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input); + input2_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, + CreateAddOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + protected: + int input1_; + int input2_; + int output_; +}; + +class FloatAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedAddOpModel : public BaseAddOpModel { + public: + using BaseAddOpModel::BaseAddOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// for quantized Add, the error shouldn't exceed 2*step +float GetTolerance(int min, int max) { + float kQuantizedStep = (max - min) / 255.0; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + +TEST(FloatAddOpModel, NoActivation) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); +} + +TEST(FloatAddOpModel, ActivationRELU1) { + FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.0, 0.4, 1.0, 1.0})); +} + +TEST(FloatAddOpModel, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-1.9, 0.4, 1.0, 1.3, 2.2, 2.1})) + << "With shape number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector> inputs1 = { + {0.1, 0.2, 0.3, 0.4}, {-0.8, 0.2, 0.4, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = { + {0.6, 0.4, 0.3, 0.1}, {0.6, 0.4, 0.5, -0.8}, {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = { + {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) { + float kQuantizedTolerance = GetTolerance(-1.0, 1.0); + std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, + {-0.8, 0.2, 0.7, 0.3}}; + std::vector> inputs2 = {{0.6, 0.4, 0.9, -0.8}, + {0.6, 0.4, -0.8, 0.5}}; + std::vector> results = {{-0.2, 0.6, 1.0, -0.1}, + {-0.2, 0.6, -0.1, 0.8}}; + for (int i = 0; i < inputs1.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_RELU1); + m.QuantizeAndPopulate(m.input1(), inputs1[i]); + m.QuantizeAndPopulate(m.input2(), inputs2[i]); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + results[i], kQuantizedTolerance))) + << "With test number " << i; + } +} + +TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1, 0.3, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.5, 1.0, 1.3, 2.2, 2.1}, + kQuantizedTolerance))) + << "With shape number " << i; + } +} + +} // namespace +} // namespace tflite +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc new file mode 100644 index 0000000000000000000000000000000000000000..3cee43c68b2a0af5a3fd84b33a980b74bb8f0cb4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace rnn { + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kRecurrentWeightsTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int KHiddenStateTensor = 0; +constexpr int kOutputTensor = 1; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[1], input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Resize state. + TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); + hidden_state_size_array->data[0] = batch_size; + hidden_state_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, + hidden_state_size_array)); + + // Mark hidden state as a persistent tensor. + hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, + output_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[KHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Initialize the pointer bias. + const float* bias_ptr = bias->data.f; + + const int batch_size = input->dims->data[0]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[1]; + const int input_weights_stride = input_weights->dims->data[1]; + const int recurrent_weights_stride = recurrent_weights->dims->data[1]; + + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to input, output and bias. + const float* input_ptr_batch = input->data.f + b * input_size; + float* output_ptr_batch = output->data.f + b * num_units; + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + // Output = bias + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = bias_ptr[o]; + } + + // Output += input * input_weights + for (int o = 0; o < num_units; o++) { + for (int i = 0; i < input_size; i++) { + output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; + } + input_weights_ptr += input_weights_stride; + } + + // Output += recurrent_weights * hidden_state + for (int o = 0; o < num_units; o++) { + for (int h = 0; h < num_units; h++) { + output_ptr_batch[o] += + hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; + } + recurrent_weights_ptr += recurrent_weights_stride; + } + + // Output = activation(Output) and update hidden_state + for (int o = 0; o < num_units; o++) { + output_ptr_batch[o] = + (ActivationFunctor(params->activation))(output_ptr_batch[o]); + hidden_state_ptr_batch[o] = output_ptr_batch[o]; + } + } + + return kTfLiteOk; +} + +} // namespace rnn + +TfLiteRegistration* Register_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + rnn::Prepare, rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5ecccb985e91238f1183c8f94a2b5f468758ce55 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -0,0 +1,267 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite RNN op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0 +}; + +class RNNOpModel : public SingleOpModel { + public: + RNNOpModel(int batches, int units, int size) + : batches_(batches), units_(units), input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(TensorType_FLOAT32); + recurrent_weights_ = AddInput(TensorType_FLOAT32); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RNN, BuiltinOptions_RNNOptions, + CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); + BuildInterpreter({{batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + private: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(FullyConnectedOpTest, BlackBoxTest) { + RNNOpModel rnn(2, 16, 8); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + const int input_sequence_size = sizeof(rnn_input) / sizeof(float) / + (rnn.input_size() * rnn.num_batches()); + + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(rnn.input_size(), batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output + i * rnn.num_units(); + float* golden_end = golden_start + rnn.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc new file mode 100644 index 0000000000000000000000000000000000000000..9e7a1233dac0f3cd02dc386f9d194597f38ca3b8 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -0,0 +1,200 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace concatenation { + +// This file has two implementation of Concatenation. +enum KernelType { + kReference, + kGenericOptimized, +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + int axis = params->axis; + int num_inputs = node->inputs->size; + + // The number of dimensions of the input tensors must match, and all + // dimensions except 'axis' must be equal. + TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; + TfLiteType input_type = t0->type; + TF_LITE_ENSURE(context, axis >= 0); + TF_LITE_ENSURE(context, axis < t0->dims->size); + + // TODO(ahentz): These are limitations of our implementation that could be + // removed with a bit of effort. + TF_LITE_ENSURE(context, t0->dims->size <= 4); + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); + TF_LITE_ENSURE(context, + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + + // Output dimensions will match input dimensions, except 'axis', which + // will be the sum of inputs + int sum_axis = t0->dims->data[axis]; + for (int i = 1; i < num_inputs; ++i) { + TfLiteTensor* t = &context->tensors[node->inputs->data[i]]; + TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size); + TF_LITE_ENSURE_EQ(context, t->type, input_type); + if (input_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, t->params.zero_point, t0->params.zero_point); + TF_LITE_ENSURE_EQ(context, t->params.scale, t0->params.scale); + } + for (int d = 0; d < t0->dims->size; ++d) { + if (d == axis) { + sum_axis += t->dims->data[axis]; + } else { + TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]); + } + } + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size); + for (int d = 0; d < t0->dims->size; ++d) { + output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d]; + } + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TF_LITE_ENSURE_EQ(context, output->type, input_type); + if (input_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, output->params.zero_point, + t0->params.zero_point); + TF_LITE_ENSURE_EQ(context, output->params.scale, t0->params.scale); + } + + return context->ResizeTensor(context, output, output_size); +} + +template +class VectorOfInputs { + public: + VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) { + int num_inputs = inputs.size; + + all_data_.reserve(num_inputs); + all_dims_.reserve(num_inputs); + all_dims_ptr_.reserve(num_inputs); + + for (int i = 0; i < num_inputs; ++i) { + TfLiteTensor* input = &context.tensors[inputs.data[i]]; + all_data_.push_back(GetTensorData(input)); + all_dims_.push_back(GetTensorDims(input)); + } + + // Taking the pointer from inside a std::vector is only OK if the vector is + // never modified, so we populate all_dims in the previous loop and then we + // are free to grab iterators here. + for (int i = 0; i < num_inputs; ++i) { + all_dims_ptr_.push_back(&all_dims_[i]); + } + } + const T* const* data() const { return all_data_.data(); } + const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } + + private: + std::vector all_data_; + std::vector> all_dims_; + std::vector*> all_dims_ptr_; +}; + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + +// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should +// allocate and populate these during Prepare(). +// TODO(ycling): Activation function parameter is ignored. For now we dont have +// a model with a Concatenation with fused activation function. +#define TF_LITE_CONCATENATION(type, scalar) \ + VectorOfInputs all_inputs(*context, *node->inputs); \ + type::Concatenation( \ + RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \ + all_inputs.dims(), node->inputs->size, GetTensorData(output), \ + GetTensorDims(output)) + + switch (output->type) { // Already know in/outtypes are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, float); + } else { + TF_LITE_CONCATENATION(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_CONCATENATION(reference_ops, uint8_t); + } else { + TF_LITE_CONCATENATION(optimized_ops, uint8_t); + } + break; + default: + context->ReportError(context, + "Only float32 and uint8 are currently supported."); + return kTfLiteError; + } + +#undef TF_LITE_CONCATENATION + + return kTfLiteOk; +} + +#undef TF_LITE_MACRO_DISPATCH + +} // namespace concatenation + +TfLiteRegistration* Register_CONCATENATION_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, concatenation::Prepare, + concatenation::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONCATENATION() { + // TODO(ahentz): It turns out the two versions of Concatenation are almost + // identical, so we should consider removing one. + return Register_CONCATENATION_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..499856a93cbbfbf9aa1a326912e52ce32bbbdf83 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseConcatenationOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, axis, input + // dimensions. + BaseConcatenationOpModel(const TensorData& input_template, int axis, + int num_inputs) { + std::vector> all_input_shapes; + for (int i = 0; i < num_inputs; ++i) { + all_input_shapes.push_back(input_template.shape); + AddInput(input_template); + } + output_ = AddOutput({input_template.type, /*shape=*/{}, input_template.min, + input_template.max}); + SetBuiltinOp( + BuiltinOperator_CONCATENATION, BuiltinOptions_ConcatenationOptions, + CreateConcatenationOptions(builder_, axis, ActivationFunctionType_NONE) + .Union()); + BuildInterpreter(all_input_shapes); + } + + protected: + int output_; +}; + +class ConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + PopulateTensor(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConcatenationOpModel : public BaseConcatenationOpModel { + public: + using BaseConcatenationOpModel::BaseConcatenationOpModel; + void SetInput(int index, std::initializer_list data) { + QuantizeAndPopulate(index, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(ConcatenationOpTest, ThreeDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/1, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 3, 4, 7})); +} + +TEST(ConcatenationOpTest, OneTrivialInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {1}}, /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {5.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ::testing::ElementsAre(5)); +} + +TEST(ConcatenationOpTest, TwoDimensionalOneInput) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0, + /*num_inputs=*/1); + m0.SetInput(0, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(ConcatenationOpTest, TwoInputsTwoAxis) { + // We will concatenate two tensors along different dimensions. + auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 3}}, /*axis=*/0, + /*num_inputs=*/2); + m0.SetInput(0, tensor0); + m0.SetInput(1, tensor1); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + + ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1, + /*num_inputs=*/2); + m1.SetInput(0, tensor0); + m1.SetInput(1, tensor1); + m1.Invoke(); + EXPECT_THAT(m1.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); +} + +TEST(ConcatenationOpTest, FourInputs) { + ConcatenationOpModel m0({TensorType_FLOAT32, {2, 1, 2}}, /*axis=*/2, + /*num_inputs=*/4); + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetOutput(), + ElementsAreArray({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + })); +} + +TEST(ConcatenationOpTest, FourInputsQuantized) { + QuantizedConcatenationOpModel m0({TensorType_UINT8, {2, 1, 2}, -12.7, 12.8}, + /*axis=*/2, + /*num_inputs=*/4); + + m0.SetInput(0, {1.0f, 3.0f, 4.0f, 7.0f}); + m0.SetInput(1, {1.1f, 3.1f, 4.1f, 7.1f}); + m0.SetInput(2, {1.2f, 3.2f, 4.2f, 7.2f}); + m0.SetInput(3, {1.3f, 3.3f, 4.3f, 7.3f}); + m0.Invoke(); + EXPECT_THAT(m0.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.0f, 3.0f, 1.1f, 3.1f, 1.2f, 3.2f, 1.3f, 3.3f, // + 4.0f, 7.0f, 4.1f, 7.1f, 4.2f, 7.2f, 4.3f, 7.3f, // + }))); + EXPECT_THAT(m0.GetOutput(), ElementsAreArray({ + 137, 157, 138, 158, 139, 159, 140, 160, // + 167, 197, 168, 198, 169, 199, 170, 200, // + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..c75c04baeac2ce53c6261d677dca8d72fafa0da5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -0,0 +1,425 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace conv { + +// This file has three implementation of Conv. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +struct OpData { + // IDs are the arbitrary identifiers used by TF Lite to identify and access + // memory buffers. + int im2col_id; + int hwcn_weights_id; + + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; + // Indexes are the offset to the memory buffer in the array used to keep track + // of the allocated temporaries. + int32_t im2col_index; + int32_t hwcn_weights_index; + bool need_hwcn_weights; + bool have_weights_been_transposed; + bool need_im2col; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to use as scratch space for im2col, and + // to carry information from Prepare() to Eval(). + auto* data = new OpData; + context->AddTensors(context, 1, &data->im2col_id); + context->AddTensors(context, 1, &data->hwcn_weights_id); + gemm_support::IncrementUsageCounter(context); + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); +} + +// Naive implementation of transpose for floats. Could be optimized to be more +// cache friendly, but for now it's a one-time cost on first run, and we would +// prefer to remove the need to do this at all eventually. +void TransposeFloatTensor(TfLiteTensor* input, TfLiteTensor* output) { + const int rows = output->dims->data[1]; + const int cols = output->dims->data[0]; + const float* input_data = GetTensorData(input); + float* output_data = GetTensorData(output); + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + const float in_value = input_data[i * cols + j]; + output_data[j * rows + i] = in_value; + } + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + bool hasBias = node->inputs->size == 3; + // Check number of inputs/outputs + TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + // Check dimensionality of input, filter + TF_LITE_ENSURE_EQ(context, input->dims->size, 4); + TF_LITE_ENSURE_EQ(context, filter->dims->size, 4); + // Check input channels matching filter + TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]); + + // Check types. (We assume that UINT8 refers to quantized tensors) + TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, filter->type, data_type); + + TfLiteTensor* bias = nullptr; + + // TODO(ahentz): At this point the optimized versions require 'bias'. We can + // either change that or document that convolution requires it. + TF_LITE_ENSURE(context, hasBias); + + if (hasBias) { + bias = &context->tensors[node->inputs->data[2]]; + if (data_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, bias->type, data_type); + } + TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, bias->dims->data[0], filter->dims->data[0]); + } + + int channels_out = filter->dims->data[0]; + int width = input->dims->data[2]; + int height = input->dims->data[1]; + int filter_width = filter->dims->data[2]; + int filter_height = filter->dims->data[1]; + int batches = input->dims->data[0]; + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto computeOutSize = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int outWidth = computeOutSize(width, filter_width, params->stride_width); + int outHeight = computeOutSize(height, filter_height, params->stride_height); + + data->padding.height = + ComputePadding(params->stride_height, height, filter_height, outHeight); + data->padding.width = + ComputePadding(params->stride_width, width, filter_width, outWidth); + + TF_LITE_ENSURE(context, hasBias); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = batches; + output_size->data[1] = outHeight; + output_size->data[2] = outWidth; + output_size->data[3] = channels_out; + auto output_status = context->ResizeTensor(context, output, output_size); + + if (output_status != kTfLiteOk) return output_status; + + // We don't always need to allocate im2col. It is only used in some versions + // of the optimized Conv. This test just mimics something that happens inside + // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). + data->need_im2col = + (params->stride_width != 1 || params->stride_height != 1 || + filter_width != 1 || filter_height != 1); + // If we're using the optimized multithreaded EigenTensor implementation of + // convolution, it expects the filter weights to be transposed compared to + // the normal TF Lite buffer format. Typical TF Lite weights are + // [filter_count, filter_height, filter_width, input_depth], but for the float + // implementation we need them as [filter_height, filter_width, input_depth, + // filter_count]. We get to that format by transposing, and create a temporary + // buffer to store the results. + // This path is only used for float processing, so only create the buffer if + // we're running with that data type. + data->need_hwcn_weights = (data_type == kTfLiteFloat32); + + int temporaries_count = 0; + if (data->need_im2col) { + data->im2col_index = temporaries_count; + ++temporaries_count; + } + if (data->need_hwcn_weights) { + data->hwcn_weights_index = temporaries_count; + ++temporaries_count; + } + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(temporaries_count); + + if (data->need_im2col) { + node->temporaries->data[data->im2col_index] = data->im2col_id; + + TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(4); + + int input_depth = input->dims->data[3]; + im2col_size->data[0] = output_size->data[0]; + im2col_size->data[1] = output_size->data[1]; + im2col_size->data[2] = output_size->data[2]; + im2col_size->data[3] = input_depth * filter_height * filter_width; + + TfLiteTensor* im2col = + &context->tensors[node->temporaries->data[data->im2col_index]]; + im2col->type = data_type; + im2col->allocation_type = kTfLiteArenaRw; + auto im2col_status = context->ResizeTensor(context, im2col, im2col_size); + if (im2col_status != kTfLiteOk) return im2col_status; + } + + if (data->need_hwcn_weights) { + node->temporaries->data[data->hwcn_weights_index] = data->hwcn_weights_id; + TfLiteIntArray* hwcn_weights_size = TfLiteIntArrayCreate(2); + + // Because we're treating the filter weights as a matrix when we do the + // transpose, we allocate the buffer with a two-dimensional shape, where one + // dimension is the number of elements in each filter, and the second is the + // total number of filters. + int input_depth = input->dims->data[3]; + hwcn_weights_size->data[0] = (filter_height * filter_width * input_depth); + hwcn_weights_size->data[1] = channels_out; + + TfLiteTensor* hwcn_weights = + &context->tensors[node->temporaries->data[data->hwcn_weights_index]]; + hwcn_weights->type = data_type; + hwcn_weights->allocation_type = kTfLiteDynamic; + // Make sure we release any previous allocations before we reallocate. + // TODO(petewarden): Persistent arenas would be a better fit for this, but + // they aren't fully implemented yet. + if (hwcn_weights->data.raw) { + free(hwcn_weights->data.raw); + hwcn_weights->data.raw = nullptr; + } + auto hwcn_weights_status = + context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); + if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; + hwcn_weights->data.raw = static_cast(malloc(hwcn_weights->bytes)); + + // TODO(petewarden): If Resize() is called when the size hasn't actually + // changed, this will do extra redundant work. + data->have_weights_been_transposed = false; + } + + return kTfLiteOk; +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, + TfLiteTensor* im2col, TfLiteTensor* hwcn_weights, + TfLiteTensor* output) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + auto input_offset = -input->params.zero_point; + auto filter_offset = -filter->params.zero_point; + auto output_offset = output->params.zero_point; + + if (kernel_type == kReference) { + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + output_offset, data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + } else { + optimized_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + output_offset, data->output_multiplier, data->output_shift, + data->output_activation_min, data->output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + } +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteConvParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col, + TfLiteTensor* hwcn_weights, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); + + const float* filter_data; + if (data->need_hwcn_weights) { + filter_data = GetTensorData(hwcn_weights); + } else { + filter_data = GetTensorData(filter); + } + + if (kernel_type == kReference) { + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + } else { + multithreaded_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, params->padding, output_activation_min, + output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col)); + } +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + bool hasBias = node->inputs->size == 3; + TfLiteTensor* bias = + hasBias ? &context->tensors[node->inputs->data[2]] : nullptr; + TfLiteTensor* im2col = + data->need_im2col + ? &context->tensors[node->temporaries->data[data->im2col_index]] + : nullptr; + TfLiteTensor* hwcn_weights = + data->need_hwcn_weights + ? &context->tensors[node->temporaries->data[data->hwcn_weights_index]] + : nullptr; + + if (data->need_hwcn_weights && !data->have_weights_been_transposed) { + TransposeFloatTensor(filter, hwcn_weights); + data->have_weights_been_transposed = true; + } + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/outtypes are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, + im2col, hwcn_weights, output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, + bias, im2col, hwcn_weights, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace conv + +TfLiteRegistration* Register_CONVOLUTION_REF() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_CONV_2D() { +#ifdef USE_NEON + return Register_CONVOLUTION_NEON_OPT(); +#else + return Register_CONVOLUTION_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1d0a81c3135625c07a3566f5f9a8e5401f0d4db7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -0,0 +1,440 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseConvolutionOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions(builder_, padding, stride_width, + stride_height, activation) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(ConvolutionOpTest, SimpleTestFloat32) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { + ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, + /*stride_width=*/3, /*stride_height=*/1); + m.SetInput({ + 3, 2, 1, -1, -2, -3, // + 4, 3, 2, -2, -3, -4, // + 5, 4, 3, -3, -4, -5, // + }); + m.SetFilter({ + 1, 2, // + 3, 4, // + }); + m.SetBias({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 30, -24, // + 40, -34, // + })); +} + +TEST(ConvolutionOpTest, HandCalculatedFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // No bias for this test. + m.SetBias({0}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + // This means we should end up with this matrix: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({105, 150, 183, 95, 235, 312, 357, + 178, 187, 234, 261, 121})); +} + +TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // Bias is | 10 |. + m.SetBias({10}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)+10=115 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)+10=160 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)+10=193 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)+10=105 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)+10=245 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)+10=322 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)+10=367 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)+10=188 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)+10=197 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)+10=244 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)+10=271 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)+10=131 + // This means we should end up with this matrix: + // | 115 | 160 | 193 | 105 | + // | 245 | 322 | 367 | 188 | + // | 197 | 244 | 271 | 131 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({115, 160, 193, 105, 245, 322, + 367, 188, 197, 244, 271, 131})); +} + +TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_SAME; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding, + ActivationFunctionType_RELU); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // Bias is | -200 |. + m.SetBias({-200}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)-200=-95 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)-200=-50 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)-200=-17 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)-200=-105 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)-200=35 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)-200=112 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)-200=157 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)-200=-22 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)-200=-13 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)-200=34 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)-200=61 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)-200=-79 + // All negative values are gated to zero by the Relu activation function. + // This means we should end up with this matrix: + // | 0 | 0 | 0 | 0 | + // | 35 | 112 | 157 | 0 | + // | 0 | 34 | 61 | 0 | + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0})); +} + +TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + const int filter_size = 3; + const int filter_count = 1; + const int stride_width = 1; + const int stride_height = 1; + const Padding padding = Padding_VALID; + ConvolutionOpModel m( + {TensorType_FLOAT32, + {image_batch_count, image_height, image_width, depth}}, + {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, + {TensorType_FLOAT32, {}}, stride_width, stride_height, padding); + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + m.SetFilter({1, 4, 7, 2, 5, 8, 3, 6, 9}); + // No bias for this test. + m.SetBias({0}); + + m.Invoke(); + // We're sliding the 3x3 filter across the 3x4 image, with no accesses outside + // the input because we're using the 'VALID' padding mode, giving a 2x1 + // output. + // The calculations behind the expected output are: + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // This means we should end up with this matrix: + // | 312 | 357 | + EXPECT_THAT(m.GetOutput(), ElementsAreArray({312, 357})); +} + +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this tests we set the input and output scales so that the results +// match exactly the 'non-quantized' version. +TEST(ConvolutionOpTest, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}, + /*stride_width=*/3, /*stride_height=*/1); + m.SetInput({ + 3, 2, 1, -1, -2, -3, // + 4, 3, 2, -2, -3, -4, // + 5, 4, 3, -3, -4, -5, // + }); + m.SetFilter({ + 1, 2, // + 3, 4, // + }); + m.SetBias({-1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 30, -24, // + 40, -34, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 157, 103, // + 167, 93, // + })); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..15dbfe08c82befcf001b9ed9a053528b5606053e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc @@ -0,0 +1,289 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace depthwise_conv { + +constexpr int kInputTensor = 0; +constexpr int kFilterTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +// This file has three implementation of DepthwiseConv. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +struct OpData { + TfLitePaddingValues padding; + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + // TODO(ahentz): use could use GetOptionalInputTensor() here, but we need to + // decide whether we are OK with optional tensors being completely absent, as + // opposed to having -1 as their index. + bool hasBias = NumInputs(node) == 3; + + TF_LITE_ENSURE(context, hasBias || NumInputs(node) == 2); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + TfLiteTensor* bias = nullptr; + + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 4); + + // The parameter 'depth_multiplier' is redundant, so we check here to make + // sure it is consistent with the given dimensions. + TF_LITE_ENSURE_EQ(context, + params->depth_multiplier * SizeOfDimension(input, 3), + SizeOfDimension(filter, 3)); + + const TfLiteType data_type = input->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8); + TF_LITE_ENSURE_EQ(context, output->type, data_type); + TF_LITE_ENSURE_EQ(context, filter->type, data_type); + + if (hasBias) { + bias = GetInput(context, node, kBiasTensor); + if (data_type == kTfLiteUInt8) { + TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0); + } else { + TF_LITE_ENSURE_EQ(context, bias->type, data_type); + } + TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(filter, 3), + SizeOfDimension(bias, 0)); + } + + int channels_out = SizeOfDimension(filter, 3); + int width = SizeOfDimension(input, 2); + int height = SizeOfDimension(input, 1); + int filter_width = SizeOfDimension(filter, 2); + int filter_height = SizeOfDimension(filter, 1); + int batches = SizeOfDimension(input, 0); + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto compute_out_size = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int out_width = compute_out_size(width, filter_width, params->stride_width); + int out_height = + compute_out_size(height, filter_height, params->stride_height); + + data->padding.height = + ComputePadding(params->stride_height, height, filter_height, out_height); + data->padding.width = + ComputePadding(params->stride_width, width, filter_width, out_width); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); + outputSize->data[0] = batches; + outputSize->data[1] = out_height; + outputSize->data[2] = out_width; + outputSize->data[3] = channels_out; + return context->ResizeTensor(context, output, outputSize); +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, TfLiteTensor* bias, + TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); + + void (*depthwise_conv)(const float*, const Dims<4>&, const float*, + const Dims<4>&, const float*, const Dims<4>&, int, int, + int, int, int, float, float, float*, const Dims<4>&); + if (kernel_type == kReference) { + depthwise_conv = &reference_ops::DepthwiseConv; + } else { + depthwise_conv = &optimized_ops::DepthwiseConv; + } + + depthwise_conv( + GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + params->depth_multiplier, output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output)); +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteDepthwiseConvParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + auto input_offset = -input->params.zero_point; + auto filter_offset = -filter->params.zero_point; + auto output_offset = output->params.zero_point; + + void (*depthwise_conv)(const uint8*, const Dims<4>&, int32, const uint8*, + const Dims<4>&, int32, const int32*, const Dims<4>&, + int, int, int, int, int, int32, int32, int, int32, + int32, uint8*, const Dims<4>&); + if (kernel_type == kReference) { + depthwise_conv = &reference_ops::DepthwiseConv; + } else { + depthwise_conv = &optimized_ops::DepthwiseConv; + } + + depthwise_conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), params->stride_width, + params->stride_height, data->padding.width, data->padding.height, + params->depth_multiplier, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output)); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kFilterTensor); + TfLiteTensor* bias = + (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr; + + // TODO(aselle): Consider whether float conv and quantized conv should be + // separate ops to avoid dispatch overhead here. + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + EvalFloat(context, node, params, data, input, filter, bias, + output); + break; + case kTfLiteUInt8: + EvalQuantized(context, node, params, data, input, filter, + bias, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace depthwise_conv + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_REF() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONVOLUTION_NEON_OPT() { + static TfLiteRegistration r = { + depthwise_conv::Init, depthwise_conv::Free, depthwise_conv::Prepare, + depthwise_conv::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEPTHWISE_CONV_2D() { +#ifdef USE_NEON + return Register_DEPTHWISE_CONVOLUTION_NEON_OPT(); +#else + return Register_DEPTHWISE_CONVOLUTION_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1439c8bce14ad127ed68dc54991aed8b8bb39383 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseDepthwiseConvolutionOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BaseDepthwiseConvolutionOpModel(const TensorData& input, + const TensorData& filter, + const TensorData& output) { + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class DepthwiseConvolutionOpModel : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +TEST(DepthwiseConvolutionOpTest, SimpleTest) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class QuantizedDepthwiseConvolutionOpModel + : public BaseDepthwiseConvolutionOpModel { + public: + using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this test we set the input and output scales so that the results match +// exactly the 'non-quantized' version. +TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { + QuantizedDepthwiseConvolutionOpModel m( + {TensorType_UINT8, {1, 3, 2, 2}, -63.5, 64}, + {TensorType_UINT8, {1, 2, 2, 4}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear( + { + 71, -34, 99, -20, // + 91, -26, 127, -4, // + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 198, 93, 226, 107, // + 218, 101, 254, 123, // + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e8cb396d43a58f94b08eb8dd8b05d16fd74fd2f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Ops that looks up items from matrix. +// +// Input: +// Tensor[0]: Row number to lookup, dim.size == 1, int32 +// Tensor[1]: 2-dimensional matrix of multi-dimensional items +// dim.size >= 2, any data type. +// first dimension is row, second dimension is column. +// +// Output: +// Output.dim[0] == Tensor[0].dim[0], num of lookups +// Output.dim[1] == Tensor[1].dim[1], num of items per row +// Each item in output is a raw bytes copy of corresponding item in input. +// When indices are out of bound, the ops will not succeed. +// + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace embedding_lookup { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 1); + TF_LITE_ENSURE(context, NumDimensions(value) >= 2); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + + outputSize->data[0] = SizeOfDimension(lookup, 0); + outputSize->data[1] = SizeOfDimension(value, 1); + for (int i = 2; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + return context->ResizeTensor(context, output, outputSize); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* value = GetInput(context, node, 1); + + const int row_size = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / row_size; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = lookup->data.i32[i]; + if (idx >= row_size || idx < 0) { + context->ReportError(context, "Embedding Lookup: index out of bounds."); + return kTfLiteError; + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + } + + return kTfLiteOk; +} + +} // namespace embedding_lookup + +TfLiteRegistration* Register_EMBEDDING_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, embedding_lookup::Prepare, + embedding_lookup::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c770e7f71efe83eace3640c47e03e0c7ab19e20 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from a sparse tensor in an embedding matrix. +// The sparse lookup tensor is represented by three individual tensors: lookup, +// indices, and dense_shape. The representation assume that the corresponding +// dense tensor would satisfy: +// * dense.shape = dense_shape +// * dense[tuple(indices[i])] = lookup[i] +// +// By convention, indices should be sorted. +// +// Options: +// combiner: The reduction op (SUM, MEAN, SQRTN). +// * SUM computes the weighted sum of the embedding results. +// * MEAN is the weighted sum divided by the total weight. +// * SQRTN is the weighted sum divided by the square root of the sum of the +// squares of the weights. +// +// Input: +// Tensor[0]: Ids to lookup, dim.size == 1, int32. +// Tensor[1]: Indices, int32. +// Tensor[2]: Dense shape, int32. +// Tensor[3]: Weights to use for aggregation, float. +// Tensor[4]: Params, a matrix of multi-dimensional items, +// dim.size >= 2, float. +// +// Output: +// A (dense) tensor representing the combined embeddings for the sparse ids. +// For each row in the sparse tensor represented by (lookup, indices, shape) +// the op looks up the embeddings for all ids in that row, multiplies them by +// the corresponding weight, and combines these embeddings as specified in the +// last dimension. +// +// Output.dim = [l0, ... , ln-1, e1, ..., em] +// Where dense_shape == [l0, ..., ln] and Tensor[4].dim == [e0, e1, ..., em] +// +// For instance, if params is a 10x20 matrix and ids, weights are: +// +// [0, 0]: id 1, weight 2.0 +// [0, 1]: id 3, weight 0.5 +// [1, 0]: id 0, weight 1.0 +// [2, 3]: id 1, weight 3.0 +// +// with combiner=MEAN, then the output will be a (3, 20) tensor where: +// +// output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) +// output[1, :] = (params[0, :] * 1.0) / 1.0 +// output[2, :] = (params[1, :] * 3.0) / 3.0 +// +// When indices are out of bound, the op will not succeed. + +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 5); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* ids = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(ids), 1); + TF_LITE_ENSURE_EQ(context, ids->type, kTfLiteInt32); + + TfLiteTensor* indices = GetInput(context, node, 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(indices), 2); + TF_LITE_ENSURE_EQ(context, indices->type, kTfLiteInt32); + + TfLiteTensor* shape = GetInput(context, node, 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(shape), 1); + TF_LITE_ENSURE_EQ(context, shape->type, kTfLiteInt32); + + TfLiteTensor* weights = GetInput(context, node, 3); + TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 1); + TF_LITE_ENSURE_EQ(context, weights->type, kTfLiteFloat32); + + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + SizeOfDimension(ids, 0)); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + SizeOfDimension(weights, 0)); + + TfLiteTensor* value = GetInput(context, node, 4); + TF_LITE_ENSURE(context, NumDimensions(value) >= 2); + + // Mark the output as a dynamic tensor. + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + output->allocation_type = kTfLiteDynamic; + + return kTfLiteOk; +} + +void FinalizeAggregation(TfLiteCombinerType combiner, int num_elements, + float current_total_weight, + float current_squares_weight, int embedding_size, + float* output) { + if (combiner != kTfLiteCombinerTypeSum && num_elements > 0) { + float multiplier = 1.0; + switch (combiner) { + case kTfLiteCombinerTypeMean: + multiplier = current_total_weight; + break; + case kTfLiteCombinerTypeSqrtn: + multiplier = std::sqrt(current_squares_weight); + break; + default: + break; + } + for (int k = 0; k < embedding_size; k++) { + output[k] /= multiplier; + } + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* ids = GetInput(context, node, 0); + TfLiteTensor* indices = GetInput(context, node, 1); + TfLiteTensor* dense_shape = GetInput(context, node, 2); + TfLiteTensor* weights = GetInput(context, node, 3); + TfLiteTensor* value = GetInput(context, node, 4); + + const int lookup_rank = SizeOfDimension(indices, 1); + const int embedding_rank = NumDimensions(value); + const int num_lookups = SizeOfDimension(ids, 0); + const int num_rows = SizeOfDimension(value, 0); + + // The last dimension gets replaced by the embedding. + const int output_rank = (lookup_rank - 1) + (embedding_rank - 1); + + // Make sure that the actual dense shape of the sparse tensor represented by + // (loopkup, indices, dense_shape) is consistent. + TF_LITE_ENSURE_EQ(context, SizeOfDimension(dense_shape, 0), lookup_rank); + + // Resize output tensor. + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); + int k = 0; + int embedding_size = 1; + int lookup_size = 1; + for (int i = 0; i < lookup_rank - 1; i++, k++) { + const int dim = dense_shape->data.i32[i]; + lookup_size *= dim; + output_shape->data[k] = dim; + } + for (int i = 1; i < embedding_rank; i++, k++) { + const int dim = SizeOfDimension(value, i); + embedding_size *= dim; + output_shape->data[k] = dim; + } + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_shape)); + const int output_size = lookup_size * embedding_size; + TfLiteTensorRealloc(output_size * sizeof(float), output); + + tensor_utils::ZeroVector(output->data.f, output_size); + + // Keep track of the current bucket for aggregation/combination. + int current_output_offset = 0; + float current_total_weight = 0.0; + float current_squares_weight = 0.0; + int num_elements = 0; + + for (int i = 0; i < num_lookups; i++) { + int idx = ids->data.i32[i]; + if (idx >= num_rows || idx < 0) { + context->ReportError(context, + "Embedding Lookup Sparse: index out of bounds."); + return kTfLiteError; + } + + // Check where we need to aggregate. + const int example_indices_offset = i * lookup_rank; + int output_bucket = 0; + int stride = 1; + for (int k = (lookup_rank - 1) - 1; k >= 0; k--) { + output_bucket += indices->data.i32[example_indices_offset + k] * stride; + stride *= dense_shape->data.i32[k]; + } + const int output_offset = output_bucket * embedding_size; + + // If we are in a new aggregation bucket and the combiner is not the sum, + // go back and finalize the result of the previous bucket. + if (output_offset != current_output_offset) { + FinalizeAggregation(params->combiner, num_elements, current_total_weight, + current_squares_weight, embedding_size, + &output->data.f[current_output_offset]); + + // Track next bucket. + num_elements = 0; + current_total_weight = 0.0; + current_squares_weight = 0.0; + current_output_offset = output_offset; + } + + // Add element to aggregation. + ++num_elements; + const int example_embedding_offset = idx * embedding_size; + const float w = weights->data.f[i]; + current_squares_weight += w * w; + current_total_weight += w; + for (int k = 0; k < embedding_size; k++) { + output->data.f[current_output_offset + k] += + (value->data.f[example_embedding_offset + k] * w); + } + } + + // Finalize last bucket. + FinalizeAggregation(params->combiner, num_elements, current_total_weight, + current_squares_weight, embedding_size, + &output->data.f[current_output_offset]); + + return kTfLiteOk; +} + +} // namespace + +TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dcdc5fffad9ceac1a9d23a4e91637a9ff92a8dda --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite sparse lookup op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class EmbeddingLookupSparseOpModel : public SingleOpModel { + public: + EmbeddingLookupSparseOpModel(CombinerType type, + std::initializer_list lookup_shape, + std::initializer_list indices_shape, + std::initializer_list dense_shape_shape, + std::initializer_list value_shape) { + lookup_ = AddInput(TensorType_INT32); + indices_ = AddInput(TensorType_INT32); + dense_shape_ = AddInput(TensorType_INT32); + weights_ = AddInput(TensorType_FLOAT32); + value_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + BuiltinOptions_EmbeddingLookupSparseOptions, + CreateEmbeddingLookupSparseOptions(builder_, type).Union()); + BuildInterpreter({lookup_shape, indices_shape, dense_shape_shape, + lookup_shape, value_shape}); + } + + void SetInput(std::initializer_list lookup_data, + std::initializer_list indices_data, + std::initializer_list dense_shape_data, + std::initializer_list weights_data) { + PopulateTensor(lookup_, lookup_data); + PopulateTensor(indices_, indices_data); + PopulateTensor(dense_shape_, dense_shape_data); + PopulateTensor(weights_, weights_data); + } + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int lookup_; + int weights_; + int indices_; + int dense_shape_; + int value_; + int output_; +}; + +TEST(EmbeddingLookupOpTest, SimpleTest) { + EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 2}, {2}, {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00, 6.06, 6.60, 6.66, 7.20, 7.26, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, SimpleTestMean) { + EmbeddingLookupSparseOpModel m(CombinerType_MEAN, {3}, {3, 2}, {2}, + {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) { + EmbeddingLookupSparseOpModel m(CombinerType_SQRTN, {3}, {3, 2}, {2}, + {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 2, 0, 2, 1}, {3, 2}, {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), + 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), + 7.20f / std::sqrt(20.0f), + 7.26f / + std::sqrt( + 20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * Row 3 + 4 * Row 0 + }))); +} + +TEST(EmbeddingLookupOpTest, Indices3DTest) { + EmbeddingLookupSparseOpModel m(CombinerType_SUM, {3}, {3, 3}, {3}, {4, 3, 2}); + m.SetInput({1, 3, 0}, {0, 0, 0, 2, 0, 0, 2, 0, 1}, {3, 2, 2}, + {1.0, 2.0, 4.0}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 6.00, 6.06, 6.60, + 6.66, 7.20, 7.26, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9b501878f196216a61568bfa36e6615f4dd07478 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Lookup op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class EmbeddingLookupOpModel : public SingleOpModel { + public: + EmbeddingLookupOpModel(std::initializer_list index_shape, + std::initializer_list weight_shape) { + input_ = AddInput(TensorType_INT32); + weight_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({index_shape, weight_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void Set3DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(weight_); + int rows = tensor->dims->data[0]; + int columns = tensor->dims->data[1]; + int features = tensor->dims->data[2]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < columns; j++) { + for (int k = 0; k < features; k++) { + tensor->data.f[(i * columns + j) * features + k] = function(i, j, k); + } + } + } + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int weight_; + int output_; +}; + +// TODO(ahentz): write more tests that exercise the details of the op, such as +// lookup errors and variable input shapes. +TEST(EmbeddingLookupOpTest, SimpleTest) { + EmbeddingLookupOpModel m({3}, {3, 2, 4}); + m.PopulateTensor(0, {1, 0, 2}); + m.Set3DWeightMatrix( + [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1 + 0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0 + 2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2 + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc new file mode 100644 index 0000000000000000000000000000000000000000..a77fe94e499078bc2f0660e8e49fd557ed0f625d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -0,0 +1,307 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace fully_connected { + +// This file has four implementations of FullyConnected +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, + kPie, // Used by the PIE team +}; + +struct OpData { + // The scaling factor from input to output (aka the 'real multiplier') can + // be represented as a fixed point multipler plus a left shift. + int32_t output_multiplier; + int output_shift; + // The range of the fused activation layer. For example for kNone and + // uint8_t these would be 0 and 255. + int32_t output_activation_min; + int32_t output_activation_max; +}; + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kBiasTensor = 2; +constexpr int kOutputTensor = 0; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + gemm_support::IncrementUsageCounter(context); + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + gemm_support::DecrementUsageCounter(context); + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 3); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + int input_size = 1; + for (int i = 0; i < input->dims->size; i++) { + input_size *= input->dims->data[i]; + } + + const int batch_size = input_size / filter->dims->data[1]; + const int num_units = filter->dims->data[0]; + + TF_LITE_ASSERT_EQ(input_size, batch_size * filter->dims->data[1]); + if (bias) { + TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + } + + TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(bias), 1); + + // Note that quantized inference requires that all tensors have their + // parameters set. This is usually done during quantized training. + TfLiteType data_type = input->type; + if (data_type != kTfLiteFloat32) { + double real_multiplier = 0.0; + TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( + context, input, filter, bias, output, &real_multiplier)); + QuantizeMultiplierSmallerThanOne(real_multiplier, &data->output_multiplier, + &data->output_shift); + CalculateActivationRangeUint8(params->activation, output, + &data->output_activation_min, + &data->output_activation_max); + } + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); + return kTfLiteOk; +} + +TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + int total_input_size = 1; + for (int i = 0; i < input->dims->size; i++) { + total_input_size *= input->dims->data[i]; + } + + int input_size = filter->dims->data[1]; + const int batch_size = total_input_size / filter->dims->data[1]; + const int num_units = filter->dims->data[0]; + + // Output = bias if bias tensor exists. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Compute output += weight * input + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + filter->data.f, num_units, input_size, input->data.f, batch_size, + output->data.f, /*result_stride=*/1); + + // Apply activation function + tensor_utils::ApplyActivationToVector(output->data.f, batch_size * num_units, + params->activation, output->data.f); + + return kTfLiteOk; +} + +#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \ + if (params->activation == kTfLiteActNone) { \ + macro_name(target_namespace, kNone); \ + } \ + if (params->activation == kTfLiteActRelu) { \ + macro_name(target_namespace, kRelu); \ + } \ + if (params->activation == kTfLiteActRelu6) { \ + macro_name(target_namespace, kRelu6); \ + } + +template +TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + int32_t input_offset = -input->params.zero_point; + int32_t filter_offset = -filter->params.zero_point; + int32_t output_offset = output->params.zero_point; +#define TF_LITE_FULLY_CONNECTED(type) \ + type::FullyConnected( \ + GetTensorData(input), GetTensorDims(input), input_offset, \ + GetTensorData(filter), GetTensorDims(filter), filter_offset, \ + GetTensorData(bias), GetTensorDims(bias), output_offset, \ + data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData(output), GetTensorDims(output), gemm_context) + if (kernel_type == kReference) { + TF_LITE_FULLY_CONNECTED(reference_ops); + } else if (kernel_type == kPie) { + // TODO(ahentz): we don't have a quantized version of the PIE kernels, so + // we just defer to the MINI ones. + TF_LITE_FULLY_CONNECTED(optimized_ops); + } else { + TF_LITE_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_FULLY_CONNECTED + + return kTfLiteOk; +} + +template +TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_FULLY_CONNECTED(type) \ + type::FullyConnected(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(filter), GetTensorDims(filter), \ + GetTensorData(bias), GetTensorDims(bias), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_FULLY_CONNECTED(reference_ops); + } else if (kernel_type == kPie) { + return EvalPie(context, node, params, data, input, filter, bias, output); + } else { + TF_LITE_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_FULLY_CONNECTED + + return kTfLiteOk; +} + +#undef TF_LITE_MACRO_DISPATCH + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + return EvalFloat(context, node, params, data, input, filter, + bias, output); + case kTfLiteUInt8: + return EvalQuantized(context, node, params, data, input, + filter, bias, output); + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace fully_connected + +TfLiteRegistration* Register_FULLY_CONNECTED_REF() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_NEON_OPT() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() { + static TfLiteRegistration r = { + fully_connected::Init, fully_connected::Free, fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED_PIE() { + static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free, + fully_connected::Prepare, + fully_connected::Eval}; + return &r; +} + +TfLiteRegistration* Register_FULLY_CONNECTED() { + // TODO(ahentz): We don't have a dedicated quantized version of the PIE + // kernel. For now, the quantized version just defer to the corresponding + // optimized MINI kernel. At some point we will allow different libraries to + // be built with different kernels, but for now we have to pick one here. + return Register_FULLY_CONNECTED_PIE(); +#ifdef USE_NEON + return Register_FULLY_CONNECTED_NEON_OPT(); +#else + return Register_FULLY_CONNECTED_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a0f766c4f4580d7679275c0b63aa200410fcb5ad --- /dev/null +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -0,0 +1,376 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite FULLY_CONNECTED op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +static float fully_connected_input[] = { + 0.503691, 0.196961, 0.521017, 0.554248, 0.288678, 0.792476, 0.561653, + 0.462230, 0.650736, 0.163132, 0.029658, 0.411544, 0.470539, 0.572390, + 0.538755, 0.212030, 0.264309, 0.193908, 0.777480, 0.745661, 0.423314, + 0.470804, 0.175501, 0.492225, 0.192743, 0.540183, 0.372514, 0.446550, + 0.498173, 0.126472, 0.132706, 0.001864, 0.323433, 0.653723, 0.556112, + 0.612111, 0.446199, 0.117765, 0.074341, 0.096935, 0.280897, 0.103999, + 0.508479, 0.751437, 0.676389, 0.047234, 0.963467, 0.940698, 0.241142, + 0.740947, 0.686359, 0.664456, 0.211751, 0.861860, 0.156681, 0.404494, + 0.402043, 0.529195, 0.851044, 0.900216, 0.655667, 0.983750, 0.902081, + 0.979100, 0.637473, 0.458193, 0.591211, 0.083671, 0.575958, 0.665552, + 0.180606, 0.856856, 0.769551, 0.689086, 0.608293, 0.445940, 0.736320, + 0.571760, 0.386637, 0.977461, 0.312707, 0.072996, 0.641918, 0.524458, + 0.934856, 0.798598, 0.928951, 0.336899, 0.327793, 0.779995, 0.237115, + 0.983460, 0.763746, 0.139196, 0.962560, 0.401218, 0.597389, 0.553771, + 0.484890, 0.173347, 0.219322, 0.665496, 0.030203, 0.988873, 0.354582, + 0.638496, 0.434813, 0.090902, 0.210256, 0.821450, 0.068363, 0.522962, + 0.894446, 0.710280, 0.047420, 0.829302, 0.508879, 0.976371, 0.166202, + 0.836672, 0.756367, 0.403317, 0.820132, 0.520112, 0.542513, 0.782691, + 0.921330, 0.139902}; + +static float fully_connected_golden_output[] = { + 0, 0.0732134, 0, 0, 0, 0.280859, + 0, 0.128927, 0, 0.0777251, 0, 0.270268, + 0.271435, 0.0173503, 0.335465, 0.235562, + + 0, 0.0745866, 0, 0.051611, 0, 0.253876, + 0, 0.0814873, 0, 0.104104, 0, 0.248529, + 0.264194, 0, 0.302973, 0.166252, + + 0, 0.0170409, 0, 0.0509851, 0, 0.212834, + 0, 0.0208326, 0, 0.129932, 0.203978, 0.103428, + 0.298051, 0, 0.332233, 0.00445903, + + 0, 0.125246, 0, 0.0735336, 0, 0.0910256, + 0, 0, 0, 0.18933, 0.378111, 0.0712443, + 0.277298, 0.0123414, 0.267454, 0, + + 0, 0.14687, 0, 0.155495, 0.0300215, 0.147256, + 0, 0, 0, 0.156412, 0.434914, 0.0461529, + 0.246508, 0, 0.363138, 0, + + 0, 0, 0, 0.0212949, 0, 0.301708, + 0, 0.35497, 0, 0.406223, 0.0260211, 0.049195, + 0.197161, 0, 0.37316, 0, + + 0, 0.221783, 0, 0, 0.0116515, 0.281945, + 0, 0, 0, 0, 0.285626, 0.181773, + 0.296401, 0.170452, 0.367135, 0.142597, + + 0, 0, 0, 0, 0, 0.418886, + 0, 0.291063, 0, 0.227541, 0.0424759, 0.27589, + 0.398286, 0.177146, 0.40359, 0.121452, + + 0, 0.0834884, 0, 0, 0, 0.287441, + 0, 0.0046838, 0, 0.0122087, 0, 0.217376, + 0.140183, 0.0948412, 0.436677, 0.0589876, + + 0, 0.0289969, 0, 0.0921397, 0, 0.396802, + 0, 0.0126157, 0, 0.0968433, 0, 0.172271, + 0.173295, 0.0664741, 0.53645, 0.00915603, + + 0, 0, 0, 0, 0, 0.147942, + 0, 0.263795, 0, 0.39782, 0, 0.382435, + 0.561072, 0.0579847, 0.145712, 0.13508, + + 0, 0, 0, 0.16382, 0, 0.322294, + 0, 0.163798, 0, 0.405211, 0.367953, 0.076852, + 0.342473, 0.0834118, 0.377537, 0, + + 0, 0.206, 0, 0, 0, 0.375769, + 0, 0, 0, 0, 0, 0.125165, + 0, 0.105591, 0.52055, 0.0536445, + + 0, 0.259261, 0, 0, 0, 0.247707, + 0, 0, 0, 0, 0, 0.215862, + 0.149153, 0.224678, 0.359519, 0.129419, + + 0, 0.17611, 0, 0.280895, 0, 0.576484, + 0, 0.000418848, 0, 0, 0, 0.151112, + 0.211902, 0, 0.566341, 0.106305, + + 0, 0.0246284, 0, 0, 0, 0.196267, + 0, 0.0248624, 0, 0.265635, 0, 0.436199, + 0.408079, 0.134514, 0.328489, 0.411368}; + +class BaseFullyConnectedOpModel : public SingleOpModel { + public: + // TODO(ahentz): test different activation types too. + BaseFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +class FloatFullyConnectedOpModel : public BaseFullyConnectedOpModel { + public: + using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedFullyConnectedOpModel : public BaseFullyConnectedOpModel { + public: + using BaseFullyConnectedOpModel::BaseFullyConnectedOpModel; + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + void SetWeights(std::initializer_list data) { + QuantizeAndPopulate(weights_, data); + } + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// TODO(ahentz): add more small tests like this one, focused on making sure the +// calculations are correct. +TEST(FullyConnectedOpTest, SimpleTest) { + FloatFullyConnectedOpModel m(3, 2, {TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +TEST(FullyConnectedOpTest, SimpleTestQuantized) { + QuantizedFullyConnectedOpModel m( + 3, 2, + /*input=*/{TensorType_UINT8, {2, 10}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); +} + +TEST(FullyConnectedOpTest, SimpleTest4DInput) { + // Note that it is not required that the first dimension be the number of + // batches. All we care is that the input can be evenly distributed in + // batches. In this case, we need the input to have multiples of '2'. + FloatFullyConnectedOpModel m(/*units=*/3, + /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {4, 1, 5, 1}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // first batch + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // second batch + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 24, 25, 26, // first batch + 58, 59, 60, // second batch + })); +} + +TEST(FullyConnectedOpTest, SimpleTest4dInputQuantized) { + QuantizedFullyConnectedOpModel m( + 3, 2, + /*input=*/{TensorType_UINT8, {4, 1, 5, 1}, -63.5, 64}, + /*output=*/{TensorType_UINT8, {}, -127, 128}); + + // input_product_scale < output_scale was not true. + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), ElementsAreArray(ArrayFloatNear({ + 24, 25, 26, // + 58, 59, 60, // + }))); + EXPECT_THAT(m.GetOutput(), ElementsAre(151, 152, 153, 185, 186, 187)); +} + +// TODO(ahentz): Reconsider this test. Having arbitrary weights makes it hard +// to debug errors and doesn't necessarily test all the important details. +TEST(FullyConnectedOpTest, BlackBoxTest) { + FloatFullyConnectedOpModel m(16, 2, {TensorType_FLOAT32, {2, 8}}); + m.SetWeights( + {0.091327, 0.103366, -0.316505, -0.083120, 0.149366, -0.196636, + -0.123672, 0.062800, 0.063031, 0.191670, -0.062001, -0.061504, + -0.275581, 0.059388, -0.118497, -0.079224, 0.109758, 0.008307, + -0.062657, -0.060962, -0.049782, -0.106719, -0.319482, -0.103650, + 0.266455, 0.051517, -0.123448, 0.322464, 0.043282, -0.173782, + -0.190381, 0.002013, 0.096086, 0.131157, 0.031164, 0.100638, + -0.312191, -0.080923, -0.101318, -0.116614, 0.142238, 0.086540, + -0.139154, 0.174268, -0.073161, 0.080072, 0.006874, 0.229382, + -0.104321, -0.176035, -0.208587, -0.001019, -0.162032, 0.080824, + -0.025021, 0.074460, -0.252595, -0.161750, -0.136403, 0.008308, + 0.005710, 0.096600, 0.289839, 0.218816, -0.304651, -0.070958, + 0.054598, 0.147113, -0.139112, -0.072798, -0.163335, -0.167863, + -0.128762, -0.035780, 0.117262, 0.017177, 0.263335, -0.176612, + 0.262961, -0.093654, -0.339283, 0.333071, 0.180827, 0.287583, + 0.066350, -0.197947, -0.114449, -0.236035, 0.103532, -0.034284, + 0.093299, -0.145361, 0.054001, 0.250570, 0.157010, -0.143480, + -0.139061, -0.048873, 0.067557, 0.139038, 0.324106, 0.227041, + 0.037793, -0.225747, -0.241619, 0.357835, 0.135762, -0.306764, + -0.125982, 0.091916, 0.266587, 0.030135, 0.265148, 0.141627, + 0.020120, 0.083815, -0.124556, -0.100124, -0.048159, 0.181172, + 0.302309, -0.041084, 0.146334, -0.061511, -0.232605, 0.281324, + 0.145408, -0.221897}); + m.SetBias({-0.160594, 0.205770, -0.078307, -0.077984, 0.001937, 0.015860, + 0.036810, 0.012346, 0.001028, 0.038551, 0.075415, 0.020804, + 0.048478, -0.032270, 0.175688, -0.085662}); + + const int input_sequence_size = sizeof(fully_connected_input) / + sizeof(float) / + (m.input_size() * m.num_batches()); + for (int i = 0; i < input_sequence_size; i++) { + // TODO(ahentz): This is what the original test was doing: two equal + // batches per invocation. We could instead use two different batches. + float* batch_start = fully_connected_input + i * m.input_size(); + float* batch_end = batch_start + m.input_size(); + m.SetInput(0, batch_start, batch_end); + m.SetInput(m.input_size(), batch_start, batch_end); + + m.Invoke(); + + float* golden_start = fully_connected_golden_output + i * m.num_units(); + float* golden_end = golden_start + m.num_units(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb2b0aacf7ecc3ed5dbde5ccce7a46dcda0a93b3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -0,0 +1,68 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/gemm_support.h" + +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace gemm_support { + +struct RefCountedGemmContext { + gemmlowp::GemmContext* gemm_context_ = nullptr; + int num_references_ = 0; +}; + +void IncrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + ptr = new RefCountedGemmContext; + ptr->gemm_context_ = new gemmlowp::GemmContext(); + ptr->num_references_ = 0; + context->gemm_context = ptr; + } + ptr->num_references_++; +} + +void DecrementUsageCounter(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to DecrementUsageCounter() not preceded by " + "IncrementUsageCounter()"); + } + if (--ptr->num_references_ == 0) { + delete ptr->gemm_context_; + delete ptr; + context->gemm_context = nullptr; + } +} + +gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { + auto* ptr = reinterpret_cast(context->gemm_context); + if (ptr == nullptr) { + TF_LITE_FATAL( + "Call to GetFromContext() not preceded by IncrementUsageCounter()"); + } + return ptr->gemm_context_; +} + +void SetMaxNumThreads(TfLiteContext* context, int num_threads) { + IncrementUsageCounter(context); + GetFromContext(context)->set_max_num_threads(num_threads); + DecrementUsageCounter(context); +} + +} // namespace gemm_support +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h new file mode 100644 index 0000000000000000000000000000000000000000..b531959ffb143c774ee715743480b03ebfbdc114 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { +namespace gemm_support { + +// Returns the GemmContext stored in 'context', allowing multiple ops to +// share a single object, as long as they share a TfLiteContext. The caller +// must ensure that this is called between IncrementUsageCounter() and +// DecrementUsageCounter(). For example, in the implementation of an op: +// void* Init(TfLiteContext* context, const char*, size_t) { +// gemm_support::IncrementUsageCounter(context); +// return nullptr; +// } +// void Free(TfLiteContext* context, void*) { +// gemm_support::DecrementUsageCounter(context); +// } +// TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { +// auto* gemm_context = gemm_support::GetFromContext(context); +// } +gemmlowp::GemmContext* GetFromContext(TfLiteContext* context); + +// Let the framework know that the GemmContext stored in 'context' will be used +// by an op. If necessary a new GemmContext is created and placed in 'context'. +void IncrementUsageCounter(TfLiteContext* context); + +// Let the framework know that the op stopped using the GemmContext stored in +// 'context'. If there are no more usages the GemmContext will be deleted. +void DecrementUsageCounter(TfLiteContext* context); + +// Set the maximum number threads available for gemmlowp operations. +void SetMaxNumThreads(TfLiteContext* context, int num_threads); + +} // namespace gemm_support +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b82601d119b2e4946db6e3577300168c7e710b6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc @@ -0,0 +1,155 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op that looks up items from hashtable. +// +// Input: +// Tensor[0]: Hash key to lookup, dim.size == 1, int32 +// Tensor[1]: Key of hashtable, dim.size == 1, int32 +// *MUST* be sorted in ascending order. +// Tensor[2]: Value of hashtable, dim.size >= 1 +// Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Output[0].dim[0] == Tensor[0].dim[0], num of lookups +// Each item in output is a raw bytes copy of corresponding item in input. +// When key does not exist in hashtable, the returned bytes are all 0s. +// +// Output[1].dim = { Tensor[0].dim[0] }, num of lookups +// Each item indicates whether the corresponding lookup has a returned value. +// 0 for missing key, 1 for found key. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +int greater(const void* a, const void* b) { + return *static_cast(a) - *static_cast(b); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); + + TfLiteTensor* lookup = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1); + TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32); + + TfLiteTensor* key = GetInput(context, node, 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(key), 1); + TF_LITE_ENSURE_EQ(context, key->type, kTfLiteInt32); + + TfLiteTensor* value = GetInput(context, node, 2); + TF_LITE_ENSURE(context, NumDimensions(value) >= 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(key, 0), + SizeOfDimension(value, 0)); + if (value->type == kTfLiteString) { + TF_LITE_ENSURE_EQ(context, NumDimensions(value), 1); + } + + TfLiteTensor* hits = GetOutput(context, node, 1); + TF_LITE_ENSURE_EQ(context, hits->type, kTfLiteUInt8); + TfLiteIntArray* hitSize = TfLiteIntArrayCreate(1); + hitSize->data[0] = SizeOfDimension(lookup, 0); + + TfLiteTensor* output = GetOutput(context, node, 0); + TF_LITE_ENSURE_EQ(context, value->type, output->type); + + TfLiteStatus status = kTfLiteOk; + if (output->type != kTfLiteString) { + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(NumDimensions(value)); + outputSize->data[0] = SizeOfDimension(lookup, 0); + for (int i = 1; i < NumDimensions(value); i++) { + outputSize->data[i] = SizeOfDimension(value, i); + } + status = context->ResizeTensor(context, output, outputSize); + } + if (context->ResizeTensor(context, hits, hitSize) == kTfLiteError) { + status = kTfLiteError; + } + return status; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* hits = GetOutput(context, node, 1); + TfLiteTensor* lookup = GetInput(context, node, 0); + TfLiteTensor* key = GetInput(context, node, 1); + TfLiteTensor* value = GetInput(context, node, 2); + + const int num_rows = SizeOfDimension(value, 0); + const int row_bytes = value->bytes / num_rows; + void* pointer = nullptr; + DynamicBuffer buf; + + for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { + int idx = -1; + pointer = bsearch(&(lookup->data.i32[i]), key->data.i32, num_rows, + sizeof(int32_t), greater); + if (pointer != nullptr) { + idx = (reinterpret_cast(pointer) - (key->data.raw)) / + sizeof(int32_t); + } + + if (idx >= num_rows || idx < 0) { + if (output->type == kTfLiteString) { + buf.AddString(nullptr, 0); + } else { + memset(output->data.raw + i * row_bytes, 0, row_bytes); + } + hits->data.uint8[i] = 0; + } else { + if (output->type == kTfLiteString) { + buf.AddString(GetString(value, idx)); + } else { + memcpy(output->data.raw + i * row_bytes, + value->data.raw + idx * row_bytes, row_bytes); + } + hits->data.uint8[i] = 1; + } + } + if (output->type == kTfLiteString) { + buf.WriteToTensor(output); + } + + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_HASHTABLE_LOOKUP() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb6038f9009a3865661e7b4f075c3033166d0f91 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Lookup op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class HashtableLookupOpModel : public SingleOpModel { + public: + HashtableLookupOpModel(std::initializer_list lookup_shape, + std::initializer_list key_shape, + std::initializer_list value_shape, + TensorType type) { + lookup_ = AddInput(TensorType_INT32); + key_ = AddInput(TensorType_INT32); + value_ = AddInput(type); + output_ = AddOutput(type); + hit_ = AddOutput(TensorType_UINT8); + SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0); + BuildInterpreter({lookup_shape, key_shape, value_shape}); + } + + void SetLookup(std::initializer_list data) { + PopulateTensor(lookup_, data); + } + + void SetHashtableKey(std::initializer_list data) { + PopulateTensor(key_, data); + } + + void SetHashtableValue(const std::vector& content) { + PopulateStringTensor(value_, content); + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + for (int i = 0; i < rows; i++) { + tensor->data.f[i] = function(i); + } + } + + void SetHashtableValue(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(value_); + int rows = tensor->dims->data[0]; + int features = tensor->dims->data[1]; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < features; j++) { + tensor->data.f[i * features + j] = function(i, j); + } + } + } + + std::vector GetStringOutput() { + TfLiteTensor* output = interpreter_->tensor(output_); + int num = GetStringCount(output); + std::vector result(num); + for (int i = 0; i < num; i++) { + auto ref = GetString(output, i); + result[i] = string(ref.str, ref.len); + } + return result; + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetHit() { return ExtractVector(hit_); } + + private: + int lookup_; + int key_; + int value_; + int output_; + int hit_; +}; + +// TODO(yichengfan): write more tests that exercise the details of the op, +// such as lookup errors and variable input shapes. +TEST(HashtableLookupOpTest, Test2DInput) { + HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 2.0, 2.1, // 2-nd item + 0, 0, // Not found + 0.0, 0.1, // 0-th item + 1.0, 1.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, 0, 1, 1, + })); +} + +TEST(HashtableLookupOpTest, Test1DInput) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue([](int i) { return i * i / 10.0f; }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 0.4, // 2-nd item + 0, // Not found + 0.0, // 0-th item + 0.1, // 1-st item + }))); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +TEST(HashtableLookupOpTest, TestString) { + HashtableLookupOpModel m({4}, {3}, {3}, TensorType_STRING); + + m.SetLookup({1234, -292, -11, 0}); + m.SetHashtableKey({-11, 0, 1234}); + m.SetHashtableValue({"Hello", "", "Hi"}); + + m.Invoke(); + + EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({ + "Hi", // 2-nd item + "", // Not found + "Hello", // 0-th item + "", // 1-st item + })); + EXPECT_THAT(m.GetHit(), ElementsAreArray({ + 1, + 0, + 1, + 1, + })); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..288534099b9e090ce0c223a401b4152ca6ffb61f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -0,0 +1,359 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +tflite_deps_intel = [ + "@arm_neon_2_x86_sse", +] + +NEON_FLAGS_IF_APPLICABLE = select({ + ":arm": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ":armeabi-v7a": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + ":armv7a": [ + "-O3", + "-mfpu=neon", + "-mfloat-abi=softfp", + ], + "//conditions:default": [ + "-O3", + ], +}) + +cc_library( + name = "types", + srcs = [], + hdrs = [ + "compatibility.h", + "types.h", + ], +) + +config_setting( + name = "arm", + values = { + "cpu": "arm", + }, +) + +config_setting( + name = "arm64-v8a", + values = { + "cpu": "arm64-v8a", + }, +) + +config_setting( + name = "armv7a", + values = { + "cpu": "armv7a", + }, +) + +config_setting( + name = "armeabi-v7a", + values = { + "cpu": "armeabi-v7a", + }, +) + +config_setting( + name = "haswell", + values = { + "cpu": "haswell", + }, +) + +config_setting( + name = "ios_x86_64", + values = { + "cpu": "ios_x86_64", + }, +) + +config_setting( + name = "ios_armv7", + values = { + "cpu": "ios_armv7", + }, +) + +config_setting( + name = "ios_arm64", + values = { + "cpu": "ios_arm64", + }, +) + +config_setting( + name = "k8", + values = { + "cpu": "k8", + }, +) + +config_setting( + name = "x86", + values = { + "cpu": "x86", + }, +) + +config_setting( + name = "x86_64", + values = { + "cpu": "x86_64", + }, +) + +config_setting( + name = "darwin", + values = { + "cpu": "darwin", + }, +) + +cc_library( + name = "optimized_base", + srcs = [], + hdrs = [ + "common.h", + "optimized/depthwiseconv_float.h", + "optimized/depthwiseconv_uint8.h", + "optimized/optimized_ops.h", + ], + copts = tflite_copts(), + deps = [ + ":types", + ":round", + "//third_party/eigen3", + "@gemmlowp//:gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + "//conditions:default": [], + }), +) + +cc_library( + name = "optimized", + hdrs = [ + "optimized/eigen_spatial_convolutions.h", + "optimized/eigen_tensor_reduced_instantiations_oss.h", + "optimized/multithreaded_conv.h", + "tensor.h", + ], + deps = [ + ":optimized_base", + ":types", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:context", + "//third_party/eigen3", + ], +) + +cc_test( + name = "tensor_test", + srcs = ["tensor_test.cc"], + deps = [ + ":reference", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "round", + srcs = [], + hdrs = ["round.h"], +) + +cc_library( + name = "quantization_util", + srcs = ["quantization_util.cc"], + hdrs = [ + "compatibility.h", + "quantization_util.h", + ], + deps = [":round"], +) + +cc_test( + name = "quantization_util_test", + srcs = ["quantization_util_test.cc"], + deps = [ + ":quantization_util", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "reference_base", + srcs = [], + hdrs = [ + "common.h", + "reference/depthwiseconv_float.h", + "reference/depthwiseconv_uint8.h", + "reference/reference_ops.h", + ], + deps = [ + ":round", + ":types", + "//third_party/eigen3", + "@gemmlowp//:gemmlowp", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":haswell": tflite_deps_intel, + ":ios_x86_64": tflite_deps_intel, + ":k8": tflite_deps_intel, + ":x86": tflite_deps_intel, + ":x86_64": tflite_deps_intel, + ":darwin": tflite_deps_intel, + "//conditions:default": [], + }), +) + +cc_library( + name = "reference", + hdrs = ["tensor.h"], + deps = [ + ":types", + "//tensorflow/contrib/lite:context", + ], +) + +cc_library( + name = "portable_tensor_utils", + srcs = [ + "reference/portable_tensor_utils.cc", + ], + hdrs = [ + "reference/portable_tensor_utils.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite/kernels:op_macros", + ], +) + +cc_library( + name = "neon_tensor_utils", + srcs = [ + "optimized/neon_tensor_utils.cc", + ], + hdrs = [ + "optimized/neon_tensor_utils.h", + "optimized/tensor_utils_impl.h", + ], + copts = NEON_FLAGS_IF_APPLICABLE, + deps = [ + ":cpu_check", + ":portable_tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:activation_functor", + ], +) + +cc_library( + name = "tensor_utils", + srcs = [ + "tensor_utils.cc", + ], + hdrs = [ + "optimized/tensor_utils_impl.h", + "reference/portable_tensor_utils.h", + "tensor_utils.h", + ], + copts = NEON_FLAGS_IF_APPLICABLE, + deps = [ + "//tensorflow/contrib/lite/kernels:activation_functor", + "//tensorflow/contrib/lite:builtin_op_data", + ] + select({ + ":arm": [ + ":neon_tensor_utils", + ], + ":arm64-v8a": [ + ":neon_tensor_utils", + ], + ":armeabi-v7a": [ + ":neon_tensor_utils", + ], + ":armv7a": [ + ":neon_tensor_utils", + ], + ":ios_armv7": [ + ":neon_tensor_utils", + ], + ":ios_arm64": [ + ":neon_tensor_utils", + ], + "//conditions:default": [ + ":portable_tensor_utils", + ], + }), +) + +cc_test( + name = "tensor_utils_test", + srcs = ["tensor_utils_test.cc"], + copts = NEON_FLAGS_IF_APPLICABLE, + linkopts = select({ + "//tensorflow:android": [ + "-fPIE -pie", + ], + "//conditions:default": [], + }), + linkstatic = 1, + deps = [ + ":tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "cpu_check", + hdrs = [ + "optimized/cpu_check.h", + ], + deps = [ + ] + select( + { + "//tensorflow:android": [ + "@androidndk//:cpufeatures", + ], + "//conditions:default": [], + }, + ), +) + +exports_files(["optimized/eigen_tensor_reduced_instantiations_oss.h"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h new file mode 100644 index 0000000000000000000000000000000000000000..28f19a250629aec4d03aa71df57d31d8a5014e9f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ + +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK +#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif +#endif + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#include +#endif + +#if defined __GNUC__ && defined __SSE4_1__ +#define USE_NEON + +#define OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#pragma GCC diagnostic ignored "-Wattributes" + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wnarrowing" +#pragma GCC diagnostic ignored "-Wsequence-point" + +#include "NEON_2_SSE.h" + +#pragma GCC diagnostic pop +#endif +#endif + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +inline void GetActivationMinMax(FusedActivationFunctionType ac, + float* output_activation_min, + float* output_activation_max) { + switch (ac) { + case FusedActivationFunctionType::kNone: + *output_activation_min = std::numeric_limits::lowest(); + *output_activation_max = std::numeric_limits::max(); + break; + case FusedActivationFunctionType::kRelu: + *output_activation_min = 0.f; + *output_activation_max = std::numeric_limits::max(); + break; + case FusedActivationFunctionType::kRelu1: + *output_activation_min = -1.f; + *output_activation_max = 1.f; + break; + case FusedActivationFunctionType::kRelu6: + *output_activation_min = 0.f; + *output_activation_max = 6.f; + break; + } +} + +inline float ActivationFunctionWithMinMax(float x, float output_activation_min, + float output_activation_max) { + return std::min(std::max(x, output_activation_min), output_activation_max); +} + +// Legacy function, left for compatibility only. +template +float ActivationFunction(float x) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + return ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); +} + +inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( + int32 x, int32 quantized_multiplier, int right_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); +} + +inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( + int32 x, int32 quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h new file mode 100644 index 0000000000000000000000000000000000000000..796a03566a4bf971294dd2375f590dfd20d600f7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h @@ -0,0 +1,78 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ + +#include +#include +#include + +#ifndef TFLITE_DCHECK +#define TFLITE_DCHECK(condition) (condition) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_EQ +#define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_GE +#define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_GT +#define TFLITE_DCHECK_GT(x, y) ((x) > (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_LE +#define TFLITE_DCHECK_LE(x, y) ((x) <= (y)) ? (void)0 : assert(false) +#endif + +#ifndef TFLITE_DCHECK_LT +#define TFLITE_DCHECK_LT(x, y) ((x) < (y)) ? (void)0 : assert(false) +#endif + +// TODO(ahentz): Clean up: We should stick to the DCHECK versions. +#ifndef TFLITE_CHECK +#define TFLITE_CHECK(condition) (condition) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_EQ +#define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_GE +#define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_GT +#define TFLITE_CHECK_GT(x, y) ((x) > (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_LE +#define TFLITE_CHECK_LE(x, y) ((x) <= (y)) ? (void)0 : abort() +#endif + +#ifndef TFLITE_CHECK_LT +#define TFLITE_CHECK_LT(x, y) ((x) < (y)) ? (void)0 : abort() +#endif + +// TODO(ahentz): Clean up. +using uint8 = std::uint8_t; +using int16 = std::int16_t; +using uint16 = std::uint16_t; +using int32 = std::int32_t; +using uint32 = std::uint32_t; + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h new file mode 100644 index 0000000000000000000000000000000000000000..dea46cc12065ed34cf681916a46a55bd7a86f463 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ + +namespace tflite { + +#ifdef __ANDROID__ +#include "ndk/sources/android/cpufeatures/cpu-features.h" + +// Runtime check for Neon support on Android. +inline bool TestCPUFeatureNeon() { +#ifdef __aarch64__ + // ARM-64 always has NEON support. + return true; +#else + static bool kUseAndroidNeon = + (android_getCpuFamily() == ANDROID_CPU_FAMILY_ARM && + android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_ARMv7 && + android_getCpuFeatures() & ANDROID_CPU_ARM_FEATURE_NEON); + return kUseAndroidNeon; +#endif // __aarch64__ +} + +#elif __ARM_NEON + +inline bool TestCPUFeatureNeon() { + return true; +} + +#else + +inline bool TestCPUFeatureNeon() { + return false; +} + +#endif + +} // namespace tflite + +// NEON_OR_PORTABLE(SomeFunc, arcs) calls NeonSomeFunc(args) if Neon is both +// enabled at build time and detected at runtime, or PortableSomeFunc(args) +// otherwise. +#ifdef __ARM_ARCH_5TE__ +// Neon isn't available at all on ARMv5. +#define NEON_OR_PORTABLE(funcname, ...) Portable##funcname(__VA_ARGS__) +#else +#define NEON_OR_PORTABLE(funcname, ...) \ + TestCPUFeatureNeon() ? Neon##funcname(__VA_ARGS__) \ + : Portable##funcname(__VA_ARGS__) +#endif + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CPU_CHECK_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h new file mode 100644 index 0000000000000000000000000000000000000000..da34c8aef94b1c69e661bd33fcb518e73034c4bd --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -0,0 +1,1060 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ + +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Implementation of float DepthwiseConv + +template +struct FloatDepthwiseConvKernel {}; + +#ifdef USE_NEON + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + float32x4_t input[4]; + for (int i = 0; i < 4; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_f32(acc[0], input[0], filter[0]); + acc[1] = vmlaq_f32(acc[1], input[1], filter[1]); + acc[2] = vmlaq_f32(acc[2], input[2], filter[0]); + acc[3] = vmlaq_f32(acc[3], input[3], filter[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + const float32x2_t filters = vld1_f32(filter_ptr); + const float32x4_t filters_dup2 = vcombine_f32(filters, filters); + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the inputs + float32x4_t input[4]; + for (int i = 0; i < 4; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 4; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filters_dup2); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + const float32x4_t input = vld1q_f32(input_ptr); + input_ptr += 4; + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filters_dup2); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle 1 output pixel at a time + for (; outp < num_output_pixels; outp++) { + // Load the inputs + const float32x2_t input = vld1_f32(input_ptr); + input_ptr += 2; + // Load the accumulators from acc_buffer + float32x2_t acc = vld1_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmla_f32(acc, input, filters); + // Store the accumulators back to acc_buffer + vst1_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 16 input channels at a time. + for (; ic <= input_depth - 16; ic += 16) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(local_filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(local_filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(local_filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(local_filter_ptr + 4 * 3); + local_filter_ptr += 16; + // Load the inputs + float32x4_t input_0 = vld1q_f32(local_input_ptr + 4 * 0); + float32x4_t input_1 = vld1q_f32(local_input_ptr + 4 * 1); + float32x4_t input_2 = vld1q_f32(local_input_ptr + 4 * 2); + float32x4_t input_3 = vld1q_f32(local_input_ptr + 4 * 3); + local_input_ptr += 16; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + // Multiply-accumulate + acc_0 = vmlaq_f32(acc_0, input_0, filter_0); + acc_1 = vmlaq_f32(acc_1, input_1, filter_1); + acc_2 = vmlaq_f32(acc_2, input_2, filter_2); + acc_3 = vmlaq_f32(acc_3, input_3, filter_3); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + acc_buffer_ptr += 16; + } + // Handle 4 input channels at a time. + for (; ic <= input_depth - 4; ic += 4) { + // Load the filters + float32x4_t filter; + filter = vld1q_f32(local_filter_ptr); + local_filter_ptr += 4; + // Load the inputs + float32x4_t input; + input = vld1q_f32(local_input_ptr); + local_input_ptr += 4; + // Load the accumulators from acc_buffer + float32x4_t acc; + acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + const float input_val = *local_input_ptr++; + const float filter_val = *local_filter_ptr++; + *acc_buffer_ptr++ += filter_val * input_val; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 2 input channels at a time. + for (; ic <= input_depth - 2; ic += 2) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + const float32x2_t input = vld1_f32(local_input_ptr); + local_input_ptr += 2; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_lane_f32(acc[0], filter[0], input, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input, 0); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input, 1); + acc[3] = vmlaq_lane_f32(acc[3], filter[3], input, 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 8; + // Load the inputs + const float input_val = *local_input_ptr++; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + input_ptr += input_ptr_increment; + } + } +}; + +// Note this implementation is very slow for input_depths < 8 +// (e.g. comparable to reference implementation) see, specializations for +// input_depth=3 below. +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + float32x4x2_t input_dup2[2]; + for (int i = 0; i < 2; i++) { + const float32x4_t input = vld1q_f32(local_input_ptr + 4 * i); + input_dup2[i] = vzipq_f32(input, input); + } + local_input_ptr += 8; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlaq_f32(acc[0], filter[0], input_dup2[0].val[0]); + acc[1] = vmlaq_f32(acc[1], filter[1], input_dup2[0].val[1]); + acc[2] = vmlaq_f32(acc[2], filter[2], input_dup2[1].val[0]); + acc[3] = vmlaq_f32(acc[3], filter[3], input_dup2[1].val[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 input channels at a time. + for (; ic <= input_depth - 4; ic += 4) { + // Load the filters + float32x2_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1_f32(local_filter_ptr + 2 * i); + } + local_filter_ptr += 8; + // Load the inputs + const float32x4_t input = vld1q_f32(local_input_ptr); + local_input_ptr += 4; + // Load the accumulators from acc_buffer + float32x2_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate + acc[0] = vmla_lane_f32(acc[0], filter[0], vget_low_f32(input), 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], vget_low_f32(input), 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], vget_high_f32(input), 0); + acc[3] = vmla_lane_f32(acc[3], filter[3], vget_high_f32(input), 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 input channels at a time. + for (; ic <= input_depth - 2; ic += 2) { + // Load the filters + const float32x4_t filter = vld1q_f32(local_filter_ptr); + local_filter_ptr += 4; + // Load the inputs + const float32x2_t input = vld1_f32(local_input_ptr); + local_input_ptr += 2; + // Load the accumulators from acc_buffer + float32x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate + acc[0] = vmla_lane_f32(acc[0], vget_low_f32(filter), input, 0); + acc[1] = vmla_lane_f32(acc[1], vget_high_f32(filter), input, 1); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 4; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the inputs + const float input_val = *local_input_ptr++; + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc_buffer_ptr[i] += local_filter_ptr[i] * input_val; + } + local_filter_ptr += 2; + acc_buffer_ptr += 2; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x2_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1_f32(filter_ptr + 2 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x2_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate for each input channel there 2 outputs + acc[0] = vmla_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 6; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // NOTE: we only want 3 values, so we read it as two ops where + // the second op just duplicates the lane + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x4_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate all outputs. + acc[0] = vmlaq_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 12; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3); + float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4); + float32x4_t filter_5 = vld1q_f32(filter_ptr + 4 * 5); + float32x4_t filter_6 = vld1q_f32(filter_ptr + 4 * 6); + float32x4_t filter_7 = vld1q_f32(filter_ptr + 4 * 7); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4); + float32x4_t acc_5 = vld1q_f32(acc_buffer_ptr + 4 * 5); + float32x4_t acc_6 = vld1q_f32(acc_buffer_ptr + 4 * 6); + float32x4_t acc_7 = vld1q_f32(acc_buffer_ptr + 4 * 7); + // Multiply-accumulate + acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val); + acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val); + acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val); + acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val); + acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val); + acc_5 = vmlaq_n_f32(acc_5, filter_5, input_val); + acc_6 = vmlaq_n_f32(acc_6, filter_6, input_val); + acc_7 = vmlaq_n_f32(acc_7, filter_7, input_val); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4); + vst1q_f32(acc_buffer_ptr + 4 * 5, acc_5); + vst1q_f32(acc_buffer_ptr + 4 * 6, acc_6); + vst1q_f32(acc_buffer_ptr + 4 * 7, acc_7); + acc_buffer_ptr += 32; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const float* local_filter_ptr = filter_ptr; + const float* local_input_ptr = input_ptr; + for (int ic = 0; ic < input_depth; ic++) { + // Load the filters + float32x4_t filter[4]; + for (int i = 0; i < 4; i++) { + filter[i] = vld1q_f32(local_filter_ptr + 4 * i); + } + local_filter_ptr += 16; + // Load the inputs + const float input_val = *local_input_ptr++; + // Load the accumulators from acc_buffer + float32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 4; i++) { + acc[i] = vmlaq_n_f32(acc[i], filter[i], input_val); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vld1q_f32(input_ptr + 4 * i); + } + // Load the accumulators from acc_buffer + float32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[i] = vmlaq_f32(acc[i], input[i], filter[i]); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + float32x2_t filter = vld1_f32(filter_ptr); + float32x4_t filter_x4 = vcombine_f32(filter, filter); + int outp = 0; + + // Handle two output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the inputs + float32x2_t input_1 = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + float32x2_t input_2 = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + float32x4_t input = vcombine_f32(input_1, input_2); + + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter_x4); + + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the inputs + float32x2_t input = vld1_f32(input_ptr); + input_ptr += input_ptr_increment; + + // Load the accumulators from acc_buffer + float32x2_t acc = vld1_f32(acc_buffer_ptr); + + // Multiply-accumulate + acc = vmla_f32(acc, input, filter); + + // Store the accumulators back to acc_buffer + vst1_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + float32x4_t filter = vld1q_f32(filter_ptr); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + float32x4_t input = vld1q_f32(input_ptr); + // Load the accumulators from acc_buffer + float32x4_t acc = vld1q_f32(acc_buffer_ptr); + // Multiply-accumulate + acc = vmlaq_f32(acc, input, filter); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + input_ptr += input_ptr_increment; + } + } +}; +#endif + +// Accumulates the effect of one row of the filter, on a segment of one row +// of the output, accessing the corresponding one row of the input. +template +void FloatDepthwiseConvAccumRow(int stride, int input_depth, int input_width, + const float* input_data, int pad_width, + int depth_multiplier, int filter_width, + const float* filter_data, + int out_x_buffer_start, int out_x_buffer_end, + int output_depth, float* acc_buffer) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); +#endif + // Sanity check parameters. This is important in particular to ensure + // that we keep the number of template instantiations minimal, so we don't + // increase binary size unnecessarily. + static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); + static_assert(kFixedInputDepth || kAllowStrided, ""); + TFLITE_DCHECK(stride == 1 || kAllowStrided); + if (kFixedInputDepth) { + TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth); + } + if (kFixedDepthMultiplier) { + TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); + } + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + const int input_ptr_increment = stride * input_depth; + const float* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + // For the current (filter_x, filter_y) point in the filter, + // compute the boundaries of the corresponding output row segment. + int out_x_loop_start_unclampled = 0; + int out_x_loop_end_unclampled = 0; + if (kAllowStrided) { + if (stride == 2) { + out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 1) / 2; + } else if (stride == 4) { + out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 3) / 4; + } else { + out_x_loop_start_unclampled = + (pad_width - filter_x + stride - 1) / stride; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + stride - 1) / stride; + } + } else { + out_x_loop_start_unclampled = pad_width - filter_x; + out_x_loop_end_unclampled = pad_width + input_width - filter_x; + } + // The kernel will have to iterate on the segment of the + // output row that starts at out_x_loop_start and out_x_loop_end. + const int out_x_loop_start = + std::max(out_x_buffer_start, out_x_loop_start_unclampled); + const int out_x_loop_end = + std::min(out_x_buffer_end, out_x_loop_end_unclampled); + + float* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const float* input_ptr = input_data + in_x_origin * input_depth; + const int num_output_pixels = out_x_loop_end - out_x_loop_start; + FloatDepthwiseConvKernel::Run(num_output_pixels, + input_depth, + depth_multiplier, + input_ptr, + input_ptr_increment, + filter_base_ptr, + acc_buffer_ptr); + filter_base_ptr += output_depth; + } +} + +// generic fallback of FloatDepthwiseConvAccumRow, portable, non-templatized. +inline void FloatDepthwiseConvAccumRowGeneric( + int stride, int input_depth, int input_width, const float* input_data, + int pad_width, int depth_multiplier, int filter_width, + const float* filter_data, int out_x_buffer_start, int out_x_buffer_end, + int output_depth, float* acc_buffer) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); +#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + LOG(FATAL) + << "\n\n" + << "*****************************************************************\n" + << "* This tfmini inference code was about to use the slow generic\n" + << "* fallback implementation for a DepthwiseConv op, and we want you\n" + << "* to be aware of that so that you will know why you get terrible\n" + << "* performance.\n" + << "*\n" + << "* If you would like to carry on with the slow code, compile\n" + << "* with this preprocessor token defined:\n" + << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "*\n" + << "* The right thing to do, if you care about performance, is to add\n" + << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" + << "* The relevant parameters defining your case are:\n" + << "* stride = " << stride << "\n" + << "* input_depth = " << input_depth << "\n" + << "* depth_multiplier = " << depth_multiplier << "\n" + << "*\n" + << "* Please do not hesitate to contact benoitjacob@ with this\n" + << "* information.\n" + << "*****************************************************************\n"; +#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + const float* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int out_x_loop_start = std::max( + out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); + const int out_x_loop_end = + std::min(out_x_buffer_end, + (pad_width + input_width - filter_x + stride - 1) / stride); + + float* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const float* input_ptr = input_data + in_x_origin * input_depth; + const int input_ptr_increment = (stride - 1) * input_depth; + for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { + const float* filter_ptr = filter_base_ptr; + for (int ic = 0; ic < input_depth; ++ic) { + const float input_val = *input_ptr++; + for (int m = 0; m < depth_multiplier; m++) { + const float filter_val = *filter_ptr++; + *acc_buffer_ptr++ += filter_val * input_val; + } + } + input_ptr += input_ptr_increment; + } + filter_base_ptr += output_depth; + } +} + +// Initializes the accumulator buffer with bias values. +inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, + const float* bias_data, + float* acc_buffer) { + // TODO(benoitjacob): This might need optimized specializations + // for small output_depth values, if that ever becomes an important + // case (like it was for some quantized DepthwiseConv cases). + for (int i = 0; i < num_output_pixels; i++) { + memcpy(acc_buffer + i * output_depth, bias_data, + sizeof(acc_buffer[0]) * output_depth); + } +} + +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + static const int kAccBufferMaxSize = 2048; + float acc_buffer[kAccBufferMaxSize]; + TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth); + const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; + const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; + TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, + kAccBufferActualSize); + TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); + TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1); + + // row_accum_func will point to the core accumulation function to be used + // for this DepthwiseConv op. + using row_accum_func_t = decltype(&FloatDepthwiseConvAccumRowGeneric); + row_accum_func_t row_accum_func = nullptr; + +#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ + FIXED_DEPTH_MULTIPLIER) \ + if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \ + (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \ + depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ + row_accum_func = \ + FloatDepthwiseConvAccumRow; \ + } + +#ifdef USE_NEON + // We go over our list of kernels by decreasing order of preference + // for the cases where multiple kernels could apply. + + // Start with the fastest kernels: AllowStrided=false, fixed input depth. + + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1) + + // Next come the strided kernels: AllowStrided=true, fixed input depth. + // They are a bit less efficient, but allow stride!=1. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 4) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) + + // Finally, the kernels allowing a variable input depth, + // these are the least efficient but most general kernels. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 16) + +#endif // USE_NEON + +#undef TFMINI_USE_DEPTHWISECONV_KERNEL + + // No matching fast kernel found, use slow fallback. + if (!row_accum_func) { + row_accum_func = FloatDepthwiseConvAccumRowGeneric; + } + + // Now that we have determined row_accum_func, we can start work. + float* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; + out_x_buffer_start += kOutputPixelsInAccBuffer) { + const int out_x_buffer_end = std::min( + output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); + // We call a 'pixel' a group of activation that share all but the + // 'depth'/'channel' coordinate. num_output_pixels is the number of + // output pixels that we will accumulate in this loop iteration. + const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; + // Initialize our local accumulator with the bias values, so we don't + // have to add them later. + DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, + acc_buffer); + // Accumulation loop. Most of the time should be spent in here. + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + const int in_y = in_y_origin + filter_y; + row_accum_func(stride_width, input_depth, input_width, + input_data + in_y * input_dims.strides[2] + + b * input_dims.strides[3], + pad_width, depth_multiplier, filter_width, + filter_data + filter_y * filter_dims.strides[2], + out_x_buffer_start, out_x_buffer_end, output_depth, + acc_buffer); + } + // Finished accumulating. Now store to destination. + const int num_output_values = output_depth * num_output_pixels; + int i = 0; +// TODO(benoitjacob) optimized code goes here +#ifdef USE_NEON + // Handle 16 values at a time + for (; i <= num_output_values - 16; i += 16) { + float32x4_t acc[4]; + for (int k = 0; k < 4; k++) { + acc[k] = vld1q_f32(acc_buffer + i + 4 * k); + } + for (int k = 0; k < 4; k++) { + acc[k] = vmaxq_f32( + vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); + } + for (int k = 0; k < 4; k++) { + vst1q_f32(output_ptr + 4 * k, acc[k]); + } + output_ptr += 16; + } + // Handle 4 values at a time + for (; i <= num_output_values - 4; i += 4) { + float32x4_t acc = vld1q_f32(acc_buffer + i); + + acc = vmaxq_f32(vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc)); + + vst1q_f32(output_ptr, acc); + output_ptr += 4; + } +#endif + // Handle leftover values, one by one. This is very slow. + for (; i < num_output_values; i++) { + float acc = acc_buffer[i]; + acc = std::max(output_activation_min, + std::min(output_activation_max, acc)); + + *output_ptr++ = acc; + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + float* output_data, const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, + depth_multiplier, output_data, output_dims); +} + +} // namespace optimized_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h new file mode 100644 index 0000000000000000000000000000000000000000..051ed2a2c44a04f0473dfd26637e53865a5a51ac --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -0,0 +1,1916 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Implementation of quantized DepthwiseConv + +template +struct QuantizedDepthwiseConvKernel {}; + +#ifdef USE_NEON +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8x2_t filter_u8; + filter_u8.val[0] = vld1_u8(filter_ptr); + filter_u8.val[1] = vld1_u8(filter_ptr + 8); + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])), + vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += input_ptr_increment; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[0].val[i] = vmlal_s16(acc[0].val[i], vget_low_s16(filter[i]), + vget_low_s16(input_dup2.val[i])); + acc[1].val[i] = vmlal_s16(acc[1].val[i], vget_high_s16(filter[i]), + vget_high_s16(input_dup2.val[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += 16; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input[0])); + acc[1] = + vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input[0])); + acc[2] = vmlal_s16(acc[2], vget_low_s16(filter), vget_low_s16(input[1])); + acc[3] = + vmlal_s16(acc[3], vget_high_s16(filter), vget_high_s16(input[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x4_t acc[2]; + acc[0] = vld1q_s32(acc_buffer_ptr); + acc[1] = vld1q_s32(acc_buffer_ptr + 4); + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), vget_low_s16(input)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), vget_high_s16(input)); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc[0]); + vst1q_s32(acc_buffer_ptr + 4, acc[1]); + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(filter), + vget_low_s16(input_dup2.val[i])); + acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(filter), + vget_high_s16(input_dup2.val[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x4x2_t input_dup2 = vzip_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(filter), input_dup2.val[0]); + acc[1] = vmlal_s16(acc[1], vget_high_s16(filter), input_dup2.val[1]); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + int outp = 0; + // Handle two output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate. + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1); + acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), input, 2); + acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), input, 2); + acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), input, 3); + acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), input, 3); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 0); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 1); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 1); + + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x4_t input_dup2 = vzip_s16(input, input).val[0]; + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input_dup2); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer. + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += 16; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input[1])); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer. + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate. + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input)); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input)); + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer. + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x2_t acc = vld1_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += 2; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer. + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Duplicate the input values, 2-fold + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], filter, vget_low_s16(input_dup2.val[0])); + acc[1] = vmlal_s16(acc[1], filter, vget_high_s16(input_dup2.val[0])); + acc[2] = vmlal_s16(acc[2], filter, vget_low_s16(input_dup2.val[1])); + acc[3] = vmlal_s16(acc[3], filter, vget_high_s16(input_dup2.val[1])); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x2_t acc = vld1_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + const uint32 input = *input_ptr++ + input_offset; + + // Multiply-accumulate + acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 8 output pixels at a time. + for (; outp <= num_output_pixels - 8; outp += 8) { + // Load the accumulators from acc_buffer + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], filter, vget_low_s16(input), 0); + acc[1] = vmlal_lane_s16(acc[1], filter, vget_low_s16(input), 1); + acc[2] = vmlal_lane_s16(acc[2], filter, vget_low_s16(input), 2); + acc[3] = vmlal_lane_s16(acc[3], filter, vget_low_s16(input), 3); + acc[4] = vmlal_lane_s16(acc[4], filter, vget_high_s16(input), 0); + acc[5] = vmlal_lane_s16(acc[5], filter, vget_high_s16(input), 1); + acc[6] = vmlal_lane_s16(acc[6], filter, vget_high_s16(input), 2); + acc[7] = vmlal_lane_s16(acc[7], filter, vget_high_s16(input), 3); + + // Store the accumulators back to acc_buffer + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], filter, input, 0); + acc[1] = vmlal_lane_s16(acc[1], filter, input, 1); + acc[2] = vmlal_lane_s16(acc[2], filter, input, 2); + acc[3] = vmlal_lane_s16(acc[3], filter, input, 3); + + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + const uint32 input = *input_ptr++ + input_offset; + + // Multiply-accumulate + acc = vmlal_n_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + // Handle 4 output pixels at a time. + for (; outp <= num_output_pixels - 4; outp += 4) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Load the inputs, add input_offset. + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t input_u8 = vld1_u8(input_ptr + 8 * i); + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + input[i] = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + } + input_ptr += 16; + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = + vmlal_s16(acc[2 * i + 0], filter, vget_low_s16(input[i])); + acc[2 * i + 1] = + vmlal_s16(acc[2 * i + 1], filter, vget_high_s16(input[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + const uint8x8_t filter_u8 = vld1_u8(filter_ptr + 8 * i); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + + int outp = 0; + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer + int32x4_t acc[8]; + for (int i = 0; i < 8; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), + vget_low_s16(input), 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), + vget_low_s16(input), 1); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), + vget_low_s16(input), 2); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), + vget_low_s16(input), 3); + acc[4] = vmlal_lane_s16(acc[4], vget_low_s16(filter[0]), + vget_high_s16(input), 0); + acc[5] = vmlal_lane_s16(acc[5], vget_high_s16(filter[0]), + vget_high_s16(input), 1); + acc[6] = vmlal_lane_s16(acc[6], vget_low_s16(filter[1]), + vget_high_s16(input), 2); + acc[7] = vmlal_lane_s16(acc[7], vget_high_s16(filter[1]), + vget_high_s16(input), 3); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 8; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 32; + } + // Handle one output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + input_ptr += 4; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate + acc[0] = vmlal_lane_s16(acc[0], vget_low_s16(filter[0]), input, 0); + acc[1] = vmlal_lane_s16(acc[1], vget_high_s16(filter[0]), input, 1); + acc[2] = vmlal_lane_s16(acc[2], vget_low_s16(filter[1]), input, 2); + acc[3] = vmlal_lane_s16(acc[3], vget_high_s16(filter[1]), input, 3); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // We will have to duplicate bytes in a NEON register, 3-fold. + // We will do that by register-level table-look-up using VTBL instructions. + // Here we prepare the registers containing the table-lookup indices. + static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2}, + {2, 3, 3, 3, 4, 4, 4, 5}, + {5, 5, 6, 6, 6, 7, 7, 7}}; + uint8x8_t dup3_indices[3]; + for (int i = 0; i < 3; i++) { + dup3_indices[i] = vld1_u8(dup3_indices_array[i]); + } + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. + int16x8_t filter[3]; + uint8x8x3_t filter_u8; + filter_u8.val[0] = vld1_u8(local_filter_ptr); + filter_u8.val[1] = vld1_u8(local_filter_ptr + 8); + filter_u8.val[2] = vld1_u8(local_filter_ptr + 16); + local_filter_ptr += 24; + for (int i = 0; i < 3; i++) { + const int16x8_t filter_s16 = + vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + // Load the inputs, duplicate 3-fold, add input_offset. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + + uint8x8_t input_u8_dup3[3]; + for (int i = 0; i < 3; i++) { + input_u8_dup3[i] = vtbl1_u8(input_u8, dup3_indices[i]); + } + int16x8_t input_dup3[3]; + for (int i = 0; i < 3; i++) { + const int16x8_t input_s16_dup3 = + vreinterpretq_s16_u16(vmovl_u8(input_u8_dup3[i])); + input_dup3[i] = vaddq_s16(input_s16_dup3, vdupq_n_s16(input_offset)); + } + // Load the accumulators from acc_buffer + int32x4x3_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + acc[i].val[2] = vld1q_s32(acc_buffer_ptr + 4 * i + 16); + } + // Multiply-accumulate + for (int j = 0; j < 3; j++) { + acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(input_dup3[j]), + vget_low_s16(filter[j])); + acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(input_dup3[j]), + vget_high_s16(filter[j])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + vst1q_s32(acc_buffer_ptr + 4 * i + 16, acc[i].val[2]); + } + acc_buffer_ptr += 24; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + const int16 input_val = *local_input_ptr++ + input_offset; + for (int i = 0; i < 3; i++) { + const int16 filter_val = local_filter_ptr[i] + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + local_filter_ptr += 3; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. + int16x8_t filter[2]; + uint8x8x2_t filter_u8; + filter_u8.val[0] = vld1_u8(local_filter_ptr); + filter_u8.val[1] = vld1_u8(local_filter_ptr + 8); + local_filter_ptr += 16; + for (int i = 0; i < 2; i++) { + const int16x8_t filter_s16 = + vreinterpretq_s16_u16(vmovl_u8(filter_u8.val[i])); + filter[i] = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + } + // Load the inputs, add input_offset, duplicate 2-fold. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + const int16x8x2_t input_dup2 = vzipq_s16(input, input); + // Load the accumulators from acc_buffer. + int32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_s32(acc_buffer_ptr + 4 * i); + acc[i].val[1] = vld1q_s32(acc_buffer_ptr + 4 * i + 8); + } + // Multiply-accumulate. + for (int j = 0; j < 2; j++) { + acc[0].val[j] = vmlal_s16(acc[0].val[j], vget_low_s16(filter[j]), + vget_low_s16(input_dup2.val[j])); + acc[1].val[j] = vmlal_s16(acc[1].val[j], vget_high_s16(filter[j]), + vget_high_s16(input_dup2.val[j])); + } + // Store the accumulators back to acc_buffer. + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i].val[0]); + vst1q_s32(acc_buffer_ptr + 4 * i + 8, acc[i].val[1]); + } + acc_buffer_ptr += 16; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + // Load the inputs. + const int16 input_val = *local_input_ptr++ + input_offset; + for (int i = 0; i < 2; i++) { + const int16 filter_val = local_filter_ptr[i] + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + local_filter_ptr += 2; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + const uint8* local_filter_ptr = filter_ptr; + const uint8* local_input_ptr = input_ptr; + int ic = 0; + // Handle 16 input channels at a time. + for (; ic <= input_depth - 16; ic += 16) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(local_filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(local_filter_ptr + 8 * 1); + local_filter_ptr += 16; + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + // Load the inputs, add input_offset. + uint8x8_t input_u8_0 = vld1_u8(local_input_ptr + 8 * 0); + uint8x8_t input_u8_1 = vld1_u8(local_input_ptr + 8 * 1); + local_input_ptr += 16; + int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0)); + int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1)); + input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset)); + input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), vget_low_s16(filter_0)); + acc_1 = + vmlal_s16(acc_1, vget_high_s16(input_0), vget_high_s16(filter_0)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(input_1), vget_low_s16(filter_1)); + acc_3 = + vmlal_s16(acc_3, vget_high_s16(input_1), vget_high_s16(filter_1)); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + acc_buffer_ptr += 16; + } + // Handle 8 input channels at a time. + for (; ic <= input_depth - 8; ic += 8) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(local_filter_ptr); + local_filter_ptr += 8; + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = + vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(local_input_ptr); + local_input_ptr += 8; + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter)); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + // Handle one input channel at a time. + for (; ic < input_depth; ic++) { + const int16 input_val = *local_input_ptr++ + input_offset; + const int16 filter_val = *local_filter_ptr++ + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8[2]; + for (int i = 0; i < 2; i++) { + filter_u8[i] = vld1_u8(filter_ptr + 8 * i); + } + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i])); + } + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. + uint8x8_t input_u8[2]; + for (int i = 0; i < 2; i++) { + input_u8[i] = vld1_u8(input_ptr + 8 * i); + } + input_ptr += input_ptr_increment; + int16x8_t input[2]; + for (int i = 0; i < 2; i++) { + input[i] = vreinterpretq_s16_u16(vmovl_u8(input_u8[i])); + } + for (int i = 0; i < 2; i++) { + input[i] = vaddq_s16(input[i], vdupq_n_s16(input_offset)); + } + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = vmlal_s16(acc[2 * i + 0], vget_low_s16(input[i]), + vget_low_s16(filter[i])); + acc[2 * i + 1] = vmlal_s16(acc[2 * i + 1], vget_high_s16(input[i]), + vget_high_s16(filter[i])); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); + const int16x8_t filter = vaddq_s16(filter_s16, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. + const uint8x8_t input_u8 = vld1_u8(input_ptr); + const int16x8_t input_s16 = vreinterpretq_s16_u16(vmovl_u8(input_u8)); + const int16x8_t input = vaddq_s16(input_s16, vdupq_n_s16(input_offset)); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_s16(acc[0], vget_low_s16(input), vget_low_s16(filter)); + acc[1] = vmlal_s16(acc[1], vget_high_s16(input), vget_high_s16(filter)); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8[2]; + for (int i = 0; i < 2; i++) { + filter_u8[i] = vld1_u8(filter_ptr + 8 * i); + } + int16x8_t filter[2]; + for (int i = 0; i < 2; i++) { + filter[i] = vreinterpretq_s16_u16(vmovl_u8(filter_u8[i])); + } + for (int i = 0; i < 2; i++) { + filter[i] = vaddq_s16(filter[i], vdupq_n_s16(filter_offset)); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + for (int i = 0; i < 2; i++) { + acc[2 * i + 0] = + vmlal_n_s16(acc[2 * i + 0], vget_low_s16(filter[i]), input); + acc[2 * i + 1] = + vmlal_n_s16(acc[2 * i + 1], vget_high_s16(filter[i]), input); + } + // Store the accumulators back to acc_buffer + for (int i = 0; i < 4; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 16; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1); + uint8x8_t filter_u8_2 = vld1_u8(filter_ptr + 8 * 2); + uint8x8_t filter_u8_3 = vld1_u8(filter_ptr + 8 * 3); + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + int16x8_t filter_2 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_2)); + int16x8_t filter_3 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_3)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + filter_2 = vaddq_s16(filter_2, vdupq_n_s16(filter_offset)); + filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4); + int32x4_t acc_5 = vld1q_s32(acc_buffer_ptr + 4 * 5); + int32x4_t acc_6 = vld1q_s32(acc_buffer_ptr + 4 * 6); + int32x4_t acc_7 = vld1q_s32(acc_buffer_ptr + 4 * 7); + // Multiply-accumulate + acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input); + acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input); + acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input); + acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input); + acc_4 = vmlal_n_s16(acc_4, vget_low_s16(filter_2), input); + acc_5 = vmlal_n_s16(acc_5, vget_high_s16(filter_2), input); + acc_6 = vmlal_n_s16(acc_6, vget_low_s16(filter_3), input); + acc_7 = vmlal_n_s16(acc_7, vget_high_s16(filter_3), input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4); + vst1q_s32(acc_buffer_ptr + 4 * 5, acc_5); + vst1q_s32(acc_buffer_ptr + 4 * 6, acc_6); + vst1q_s32(acc_buffer_ptr + 4 * 7, acc_7); + acc_buffer_ptr += 32; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + const uint8x8_t filter_u8 = vld1_u8(filter_ptr); + const int16x8_t filter = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i] = vld1q_s32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate + acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter), input); + acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter), input); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 2; i++) { + vst1q_s32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 8; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + + // Handle 2 output pixels at a time. + for (; outp <= num_output_pixels - 2; outp += 2) { + // Load the accumulators from acc_buffer. + int32x4_t acc = vld1q_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint16x4_t input_u16 = vdup_n_u16(0); + input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], + input_u16, 0); + input_ptr += input_ptr_increment; + input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], + input_u16, 1); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = vreinterpret_s16_u16( + vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16)))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer. + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + + // Handle 1 output pixel at a time. + for (; outp < num_output_pixels; outp++) { + // Load the accumulators from acc_buffer. + int32x2_t acc = vld1_s32(acc_buffer_ptr); + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + + // Multiply-accumulate. + acc = vget_low_s32(vmlal_s16(vcombine_s32(acc, acc), filter, input)); + // Store the accumulators back to acc_buffer. + vst1_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 2; + } + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + if (num_output_pixels <= 0) { + return; + } + + // Load the filters, add filter_offset. + uint8x8_t filter_u8 = vdup_n_u8(0); + filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); + filter_u8 = vset_lane_u8(filter_ptr[1], filter_u8, 1); + filter_u8 = vset_lane_u8(filter_ptr[2], filter_u8, 2); + filter_u8 = vset_lane_u8(filter_ptr[3], filter_u8, 3); + const int16x4_t filter_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(filter_u8))); + const int16x4_t filter = vadd_s16(filter_s16, vdup_n_s16(filter_offset)); + + int outp = 0; + + // Handle one output pixel at a time until second to the last pixel. Second + // to the last because we read eight input pixels while only processing + // four. + for (; outp < num_output_pixels - 1; outp++) { + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vld1_u8(input_ptr); + input_ptr += input_ptr_increment; + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + acc_buffer_ptr += 4; + } + + // Handle the last output pixel. + // Load the accumulators from acc_buffer + int32x4_t acc; + acc = vld1q_s32(acc_buffer_ptr); + + // Load the inputs, add input_offset. + uint8x8_t input_u8 = vdup_n_u8(0); + input_u8 = vset_lane_u8(input_ptr[0], input_u8, 0); + input_u8 = vset_lane_u8(input_ptr[1], input_u8, 1); + input_u8 = vset_lane_u8(input_ptr[2], input_u8, 2); + input_u8 = vset_lane_u8(input_ptr[3], input_u8, 3); + const int16x4_t input_s16 = + vreinterpret_s16_u16(vget_low_u16(vmovl_u8(input_u8))); + const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); + // Multiply-accumulate + acc = vmlal_s16(acc, filter, input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr, acc); + } +}; + +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4); + int16x8_t filter_s16_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_s16_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + filter_s16_0 = vaddq_s16(filter_s16_0, vdupq_n_s16(filter_offset)); + filter_s16_1 = vaddq_s16(filter_s16_1, vdupq_n_s16(filter_offset)); + int16x4_t filter_0 = vget_low_s16(filter_s16_0); + int16x4_t filter_1 = vget_high_s16(filter_s16_0); + int16x4_t filter_2 = vget_high_s16(filter_s16_1); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs, add input_offset. + uint8x8_t input_u8_0 = vld1_u8(input_ptr); + uint8x8_t input_u8_1 = vld1_u8(input_ptr + 4); + input_ptr += input_ptr_increment; + int16x8_t input_0 = vreinterpretq_s16_u16(vmovl_u8(input_u8_0)); + int16x8_t input_1 = vreinterpretq_s16_u16(vmovl_u8(input_u8_1)); + input_0 = vaddq_s16(input_0, vdupq_n_s16(input_offset)); + input_1 = vaddq_s16(input_1, vdupq_n_s16(input_offset)); + + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + + // Multiply-accumulate + acc_0 = vmlal_s16(acc_0, vget_low_s16(input_0), filter_0); + acc_1 = vmlal_s16(acc_1, vget_high_s16(input_0), filter_1); + acc_2 = vmlal_s16(acc_2, vget_high_s16(input_1), filter_2); + + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + + acc_buffer_ptr += 12; + } + } +}; +#endif + +// Accumulates the effect of one row of the filter, on a segment of one row +// of the output, accessing the corresponding one row of the input. +template +void QuantizedDepthwiseConvAccumRow( + int stride, int input_depth, int input_width, const uint8* input_data, + int16 input_offset, int pad_width, int depth_multiplier, int filter_width, + const uint8* filter_data, int16 filter_offset, int out_x_buffer_start, + int out_x_buffer_end, int output_depth, int32* acc_buffer) { +#ifdef GEMMLOWP_PROFILING + gemmlowp::ScopedProfilingLabel label(__PRETTY_FUNCTION__); +#endif + // Sanity check parameters. This is important in particular to ensure + // that we keep the number of template instantiations minimal, so we don't + // increase binary size unnecessarily. + static_assert(kFixedDepthMultiplier || !kFixedInputDepth, ""); + static_assert(kFixedInputDepth || kAllowStrided, ""); + TFLITE_DCHECK(stride == 1 || kAllowStrided); + if (kFixedInputDepth) { + TFLITE_DCHECK_EQ(input_depth, kFixedInputDepth); + } + if (kFixedDepthMultiplier) { + TFLITE_DCHECK_EQ(depth_multiplier, kFixedDepthMultiplier); + } + TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); + const int input_ptr_increment = stride * input_depth; + const uint8* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + // For the current (filter_x, filter_y) point in the filter, + // compute the boundaries of the corresponding output row segment. + int out_x_loop_start_unclampled = 0; + int out_x_loop_end_unclampled = 0; + if (kAllowStrided) { + if (stride == 2) { + out_x_loop_start_unclampled = (pad_width - filter_x + 1) / 2; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 1) / 2; + } else if (stride == 4) { + out_x_loop_start_unclampled = (pad_width - filter_x + 3) / 4; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + 3) / 4; + } else { + out_x_loop_start_unclampled = + (pad_width - filter_x + stride - 1) / stride; + out_x_loop_end_unclampled = + (pad_width + input_width - filter_x + stride - 1) / stride; + } + } else { + out_x_loop_start_unclampled = pad_width - filter_x; + out_x_loop_end_unclampled = pad_width + input_width - filter_x; + } + // The kernel will have to iterate on the segment of the + // output row that starts at out_x_loop_start and out_x_loop_end. + const int out_x_loop_start = + std::max(out_x_buffer_start, out_x_loop_start_unclampled); + const int out_x_loop_end = + std::min(out_x_buffer_end, out_x_loop_end_unclampled); + + int32* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const uint8* input_ptr = input_data + in_x_origin * input_depth; + const int num_output_pixels = out_x_loop_end - out_x_loop_start; + QuantizedDepthwiseConvKernel< + kAllowStrided, kFixedInputDepth, + kFixedDepthMultiplier>::Run(num_output_pixels, input_depth, + depth_multiplier, input_ptr, input_offset, + input_ptr_increment, filter_base_ptr, + filter_offset, acc_buffer_ptr); + filter_base_ptr += output_depth; + } +} + +// generic fallback of DepthwiseConvAccumRow, portable, non-templatized. +inline void QuantizedDepthwiseConvAccumRowGeneric( + int stride, int input_depth, int input_width, const uint8* input_data, + int16 input_offset, int pad_width, int depth_multiplier, int filter_width, + const uint8* filter_data, int16 filter_offset, int out_x_buffer_start, + int out_x_buffer_end, int output_depth, int32* acc_buffer) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConvAccumRowGeneric (slow)"); +#ifdef TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + LOG(FATAL) + << "\n\n" + << "*****************************************************************\n" + << "* This tfmini inference code was about to use the slow generic\n" + << "* fallback implementation for a DepthwiseConv op, and we want you\n" + << "* to be aware of that so that you will know why you get terrible\n" + << "* performance.\n" + << "*\n" + << "* If you would like to carry on with the slow code, compile\n" + << "* with this preprocessor token defined:\n" + << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "*\n" + << "* The right thing to do, if you care about performance, is to add\n" + << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" + << "* The relevant parameters defining your case are:\n" + << "* stride = " << stride << "\n" + << "* input_depth = " << input_depth << "\n" + << "* depth_multiplier = " << depth_multiplier << "\n" + << "*\n" + << "* Please do not hesitate to contact benoitjacob@ with this\n" + << "* information.\n" + << "*****************************************************************\n"; +#endif // ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK +#endif // TFLITE_PREVENT_SLOW_GENERIC_DEPTHWISECONV_FALLBACK + const uint8* filter_base_ptr = filter_data; + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int out_x_loop_start = std::max( + out_x_buffer_start, (pad_width - filter_x + stride - 1) / stride); + const int out_x_loop_end = + std::min(out_x_buffer_end, + (pad_width + input_width - filter_x + stride - 1) / stride); + + int32* acc_buffer_ptr = + acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; + const int in_x_origin = (out_x_loop_start * stride) - pad_width + filter_x; + const uint8* input_ptr = input_data + in_x_origin * input_depth; + const int input_ptr_increment = (stride - 1) * input_depth; + for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { + const uint8* filter_ptr = filter_base_ptr; + for (int ic = 0; ic < input_depth; ++ic) { + const int16 input_val = *input_ptr++ + input_offset; + for (int m = 0; m < depth_multiplier; m++) { + const int16 filter_val = *filter_ptr++ + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + } + } + input_ptr += input_ptr_increment; + } + filter_base_ptr += output_depth; + } +} + +// Initializes the accumulator buffer with bias values. +inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, + const int32* bias_data, + int32* acc_buffer) { + int i = 0; +#ifdef USE_NEON + if (output_depth == 1) { + const int32x4_t b = vdupq_n_s32(bias_data[0]); + for (; i <= num_output_pixels - 16; i += 16) { + vst1q_s32(acc_buffer + i + 0, b); + vst1q_s32(acc_buffer + i + 4, b); + vst1q_s32(acc_buffer + i + 8, b); + vst1q_s32(acc_buffer + i + 12, b); + } + for (; i <= num_output_pixels - 4; i += 4) { + vst1q_s32(acc_buffer + i, b); + } + } else if (output_depth == 2) { + int32x4_t b = vdupq_n_s32(bias_data[0]); + b = vsetq_lane_s32(bias_data[1], b, 1); + b = vsetq_lane_s32(bias_data[1], b, 3); + for (; i <= num_output_pixels - 8; i += 8) { + vst1q_s32(acc_buffer + 2 * i + 0, b); + vst1q_s32(acc_buffer + 2 * i + 4, b); + vst1q_s32(acc_buffer + 2 * i + 8, b); + vst1q_s32(acc_buffer + 2 * i + 12, b); + } + for (; i <= num_output_pixels - 2; i += 2) { + vst1q_s32(acc_buffer + 2 * i, b); + } + } else if (output_depth == 4) { + const int32x4_t b = vld1q_s32(bias_data); + for (; i <= num_output_pixels - 4; i += 4) { + vst1q_s32(acc_buffer + 4 * i + 0, b); + vst1q_s32(acc_buffer + 4 * i + 4, b); + vst1q_s32(acc_buffer + 4 * i + 8, b); + vst1q_s32(acc_buffer + 4 * i + 12, b); + } + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 4 * i, b); + } + } else if (output_depth == 8) { + const int32x4_t b0 = vld1q_s32(bias_data); + const int32x4_t b1 = vld1q_s32(bias_data + 4); + for (; i <= num_output_pixels - 2; i += 2) { + vst1q_s32(acc_buffer + 8 * i + 0, b0); + vst1q_s32(acc_buffer + 8 * i + 4, b1); + vst1q_s32(acc_buffer + 8 * i + 8, b0); + vst1q_s32(acc_buffer + 8 * i + 12, b1); + } + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 8 * i + 0, b0); + vst1q_s32(acc_buffer + 8 * i + 4, b1); + } + } else if (output_depth == 16) { + const int32x4_t b0 = vld1q_s32(bias_data); + const int32x4_t b1 = vld1q_s32(bias_data + 4); + const int32x4_t b2 = vld1q_s32(bias_data + 8); + const int32x4_t b3 = vld1q_s32(bias_data + 12); + for (; i < num_output_pixels; i++) { + vst1q_s32(acc_buffer + 16 * i + 0, b0); + vst1q_s32(acc_buffer + 16 * i + 4, b1); + vst1q_s32(acc_buffer + 16 * i + 8, b2); + vst1q_s32(acc_buffer + 16 * i + 12, b3); + } + } +#endif + for (; i < num_output_pixels; i++) { + memcpy(acc_buffer + i * output_depth, bias_data, + sizeof(acc_buffer[0]) * output_depth); + } +} + +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthwiseConv/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + static const int kAccBufferMaxSize = 2048; + int32 acc_buffer[kAccBufferMaxSize]; + TFLITE_DCHECK_GE(kAccBufferMaxSize, output_depth); + const int kOutputPixelsInAccBuffer = kAccBufferMaxSize / output_depth; + const int kAccBufferActualSize = kOutputPixelsInAccBuffer * output_depth; + TFLITE_DCHECK_LE(kOutputPixelsInAccBuffer * output_depth, + kAccBufferActualSize); + TFLITE_DCHECK_LE(kAccBufferActualSize, kAccBufferMaxSize); + TFLITE_DCHECK_GE(kOutputPixelsInAccBuffer, 1); + + // row_accum_func will point to the core accumulation function to be used + // for this DepthwiseConv op. + using row_accum_func_t = decltype(&QuantizedDepthwiseConvAccumRowGeneric); + row_accum_func_t row_accum_func = nullptr; + +#define TFMINI_USE_DEPTHWISECONV_KERNEL(ALLOW_STRIDED, FIXED_INPUT_DEPTH, \ + FIXED_DEPTH_MULTIPLIER) \ + if (!row_accum_func && (stride_width == 1 || ALLOW_STRIDED) && \ + (input_depth == FIXED_INPUT_DEPTH || FIXED_INPUT_DEPTH == 0) && \ + depth_multiplier == FIXED_DEPTH_MULTIPLIER) { \ + row_accum_func = \ + QuantizedDepthwiseConvAccumRow; \ + } + +#ifdef USE_NEON + // We go over our list of kernels by decreasing order of preference + // for the cases where multiple kernels could apply. + + // Start with the fastest kernels: AllowStrided=false, fixed input depth. + + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 1, 4) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 4, 4) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(false, 12, 1) + + // Next come the strided kernels: AllowStrided=true, fixed input depth. + // They are a bit less efficient, but allow stride!=1. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) + + // Finally, the kernels allowing a variable input depth, + // these are the least efficient but most general kernels. + + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 0, 3) +#endif // USE_NEON + + // No matching fast kernel found, use slow fallback. + if (!row_accum_func) { + row_accum_func = QuantizedDepthwiseConvAccumRowGeneric; + } + +#undef TFMINI_USE_DEPTHWISECONV_KERNEL + + // Now that we have determined row_accum_func, we can start work. + uint8* output_ptr = output_data; + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + for (int out_x_buffer_start = 0; out_x_buffer_start < output_width; + out_x_buffer_start += kOutputPixelsInAccBuffer) { + const int out_x_buffer_end = std::min( + output_width, out_x_buffer_start + kOutputPixelsInAccBuffer); + // We call a 'pixel' a group of activation that share all but the + // 'depth'/'channel' coordinate. num_output_pixels is the number of + // output pixels that we will accumulate in this loop iteration. + const int num_output_pixels = out_x_buffer_end - out_x_buffer_start; + // Initialize our local accumulator with the bias values, so we don't + // have to add them later. + DepthwiseConvInitAccBuffer(num_output_pixels, output_depth, bias_data, + acc_buffer); + // Accumulation loop. Most of the time should be spent in here. + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + const int in_y = in_y_origin + filter_y; + row_accum_func( + stride_width, input_depth, input_width, + input_data + in_y * input_dims.strides[2] + + b * input_dims.strides[3], + input_offset, pad_width, depth_multiplier, filter_width, + filter_data + filter_y * filter_dims.strides[2], filter_offset, + out_x_buffer_start, out_x_buffer_end, output_depth, acc_buffer); + } + // Finished accumulating int32 values. Now need to convert them to + // the final 8bit form and store them. + gemmlowp::ScopedProfilingLabel label("downquantize+store"); + const int num_output_values = output_depth * num_output_pixels; + int i = 0; +#ifdef USE_NEON + using gemmlowp::RoundingDivideByPOT; + const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + const int32x4_t output_activation_min_vec = + vdupq_n_s32(output_activation_min); + const int32x4_t output_activation_max_vec = + vdupq_n_s32(output_activation_max); + // Handle 16 values at once. + // This allows us to issue 4 mutually independent int32 + // multiplications (vqrdmulh), which should alleviate most of their + // high latency. + for (; i <= num_output_values - 16; i += 16) { + int32x4_t acc[4]; + for (int j = 0; j < 4; j++) { + acc[j] = vld1q_s32(acc_buffer + i + 4 * j); + } + + // Fixed-point multiplication. + for (int j = 0; j < 4; j++) { + acc[j] = vqrdmulhq_n_s32(acc[j], output_multiplier); + } + for (int j = 0; j < 4; j++) { + acc[j] = RoundingDivideByPOT(acc[j], output_shift); + } + // Add the output offset. + for (int j = 0; j < 4; j++) { + acc[j] = vaddq_s32(acc[j], output_offset_vec); + } + // Apply the activation function. + for (int j = 0; j < 4; j++) { + acc[j] = vmaxq_s32(acc[j], output_activation_min_vec); + } + for (int j = 0; j < 4; j++) { + acc[j] = vminq_s32(acc[j], output_activation_max_vec); + } + // Saturating cast to uint8 and store to destination. + int16x4_t acc_s16[4]; + for (int j = 0; j < 4; j++) { + acc_s16[j] = vqmovn_s32(acc[j]); + } + const int16x8_t res_s16_0 = vcombine_s16(acc_s16[0], acc_s16[1]); + const int16x8_t res_s16_1 = vcombine_s16(acc_s16[2], acc_s16[3]); + const uint8x8_t res_u8_0 = vqmovun_s16(res_s16_0); + const uint8x8_t res_u8_1 = vqmovun_s16(res_s16_1); + vst1q_u8(output_ptr, vcombine_u8(res_u8_0, res_u8_1)); + output_ptr += 16; + } + // Handle 8 values at once. + // Not as good as 16 (now we're only issuing 2 mutually independent + // vqrdmulh instructions, so we're probably paying for their high + // latency). + for (; i <= num_output_values - 8; i += 8) { + int32x4_t acc0 = vld1q_s32(acc_buffer + i); + int32x4_t acc1 = vld1q_s32(acc_buffer + i + 4); + // Fixed-point multiplication. + acc0 = vqrdmulhq_n_s32(acc0, output_multiplier); + acc1 = vqrdmulhq_n_s32(acc1, output_multiplier); + // Rounding right shift. + acc0 = RoundingDivideByPOT(acc0, output_shift); + acc1 = RoundingDivideByPOT(acc1, output_shift); + // Add the output offset. + acc0 = vaddq_s32(acc0, output_offset_vec); + acc1 = vaddq_s32(acc1, output_offset_vec); + // Apply the activation function. + acc0 = vmaxq_s32(acc0, output_activation_min_vec); + acc1 = vmaxq_s32(acc1, output_activation_min_vec); + acc0 = vminq_s32(acc0, output_activation_max_vec); + acc1 = vminq_s32(acc1, output_activation_max_vec); + // Saturating cast to uint8 and store to destination. + const int16x4_t acc0_s16 = vqmovn_s32(acc0); + const int16x4_t acc1_s16 = vqmovn_s32(acc1); + const int16x8_t res_s16 = vcombine_s16(acc0_s16, acc1_s16); + const uint8x8_t res_u8 = vqmovun_s16(res_s16); + vst1_u8(output_ptr, res_u8); + output_ptr += 8; + } + // Handle 4 values at once. Now we're paying the full price of the + // high latency of vqrdmulh. Also, storing only 4 bytes at the end + // (without any alignment) can only be done 1 byte at a time. + // Yet, that is still worth doing to minimize the amount of leftover + // that will have to go through the very slow scalar code. + for (; i <= num_output_values - 4; i += 4) { + int32x4_t acc = vld1q_s32(acc_buffer + i); + // Fixed-point multiplication. + acc = vqrdmulhq_n_s32(acc, output_multiplier); + // Rounding right shift. + acc = RoundingDivideByPOT(acc, output_shift); + // Add the output offset. + acc = vaddq_s32(acc, output_offset_vec); + // Apply the activation function. + acc = vmaxq_s32(acc, output_activation_min_vec); + acc = vminq_s32(acc, output_activation_max_vec); + // Saturating cast to uint8 and store to destination. + const int16x4_t acc_s16 = vqmovn_s32(acc); + const int16x8_t res_s16 = vcombine_s16(acc_s16, acc_s16); + const uint8x8_t res_u8 = vqmovun_s16(res_s16); + vst1_lane_u8(output_ptr + 0, res_u8, 0); + vst1_lane_u8(output_ptr + 1, res_u8, 1); + vst1_lane_u8(output_ptr + 2, res_u8, 2); + vst1_lane_u8(output_ptr + 3, res_u8, 3); + output_ptr += 4; + } +#endif // USE_NEON + + // Handle leftover values, one by one. This is very slow. + for (; i < num_output_values; i++) { + int32 acc = acc_buffer[i]; + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + *output_ptr++ = static_cast(acc); + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, input_offset, filter_data, + filter_dims, filter_offset, bias_data, bias_dims, stride, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +} // namespace optimized_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h new file mode 100644 index 0000000000000000000000000000000000000000..8004c24a9914e216974539930853d0aadf61e324 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -0,0 +1,231 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h. +// TODO(petewarden) - move this to a common location in Eigen itself. + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ + +#define EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_THREADS + +// NOTE: Eigen is slightly different internally and externally. We need to +// hack the unsupported/Eigen/CXX11/Tensor header instantiation macros at +// specific places, so we need two copies of the hacked file, one for +// internal and one for external. +// If you have trouble simply undef out the reducer macro e.g. +// TFLITE_REDUCE_INSTANTIATIONS_GOOGLE, but be aware this will make +// the binary much bigger! +#define TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE +#define Eigen EigenForTFLite +#if defined(TFLITE_REDUCE_INSTANTIATIONS_GOOGLE) +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h" +#elif defined(TFLITE_REDUCE_INSTANTIATIONS_OPEN_SOURCE) +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h" +#else +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#endif + + +namespace Eigen { + +/** SpatialConvolution + * \ingroup CXX11_NeuralNetworks_Module + * + * \brief Applies a 2D convolution over a multichannel input image. + * + * The input parameter is expected to be a tensor with a rank of 3 or more + * (channels, height, width, and optionally others) + * The kernel parameter is expected to be a 4D tensor (filters, channels, + * kernel_height, kernel_width) + * The input and the kernel must both be in col-major layout. The result will + * also be in col-major layout. + * + * If col_in_stride, row_in_stride > 1, then applies convolution with holes + * (aka atrous convolution), sampling every col_in_stride, row_in_stride input + * pixels. + * + * The result can be assigned to a tensor of rank equal to the rank of the + * input. The dimensions of the result will be filters, height, width (and + * others if applicable). + * + * It is possible to swap the order of the width and height dimensions provided + * that the same order is used in the input, the kernel, and the output. + * + */ +template +EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE static const typename internal::conditional< + internal::traits::Layout == ColMajor, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp > > >, + TensorReshapingOp< + const DSizes::Index, + internal::traits::NumDimensions>, + const TensorContractionOp< + const array::Index>, + 1>, + const TensorReshapingOp< + const DSizes::Index, 2>, + const TensorImagePatchOp >, + const TensorReshapingOp< + const DSizes::Index, 2>, + const Kernel> > > >::type + SpatialConvolution(const Input& input, const Kernel& kernel, + const DenseIndex row_stride = 1, + const DenseIndex col_stride = 1, + const PaddingType padding_type = PADDING_SAME, + const DenseIndex row_in_stride = 1, + const DenseIndex col_in_stride = 1) { + typedef typename internal::traits::Index TensorIndex; + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + in(input); + TensorRef::Scalar, + internal::traits::NumDimensions, + internal::traits::Layout, TensorIndex> > + kern(kernel); + + EIGEN_STATIC_ASSERT( + internal::traits::Layout == internal::traits::Layout, + YOU_MADE_A_PROGRAMMING_MISTAKE); + const bool isColMajor = (internal::traits::Layout == ColMajor); + + const int NumDims = internal::traits::NumDimensions; + + // Number of filters to apply. This is the same as the output depth of the + // result + const TensorIndex kernelFilters = + isColMajor ? kern.dimensions()[0] : kern.dimensions()[3]; + // Number of channels. This is the same as the input depth. + const TensorIndex kernelChannels = + isColMajor ? kern.dimensions()[1] : kern.dimensions()[2]; + const TensorIndex kernelRows = + isColMajor ? kern.dimensions()[2] : kern.dimensions()[1]; + const TensorIndex kernelCols = + isColMajor ? kern.dimensions()[3] : kern.dimensions()[0]; + + const DenseIndex kernelRowsEff = + kernelRows + (kernelRows - 1) * (row_in_stride - 1); + const DenseIndex kernelColsEff = + kernelCols + (kernelCols - 1) * (col_in_stride - 1); + + array, 1> contract_dims; + contract_dims[0] = IndexPair(1, 0); + + const TensorIndex InputRows = + isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); + const TensorIndex InputCols = + isColMajor ? in.dimension(2) : in.dimension(NumDims - 3); + + TensorIndex out_height; + TensorIndex out_width; + switch (padding_type) { + case PADDING_VALID: + out_height = numext::ceil((InputRows - kernelRowsEff + 1.f) / + static_cast(row_stride)); + out_width = numext::ceil((InputCols - kernelColsEff + 1.f) / + static_cast(col_stride)); + break; + case PADDING_SAME: + out_height = numext::ceil(InputRows / static_cast(row_stride)); + out_width = numext::ceil(InputCols / static_cast(col_stride)); + break; + default: + // Initialize unused variables to avoid a compiler warning + out_height = 0; + out_width = 0; + eigen_assert(false && "unexpected padding"); + } + + // Molds the output of the patch extraction code into a 2d tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes pre_contract_dims; + if (isColMajor) { + pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[1] = out_height * out_width; + for (int i = 3; i < NumDims; ++i) { + pre_contract_dims[1] *= in.dimension(i); + } + } else { + pre_contract_dims[1] = kernelChannels * kernelRows * kernelCols; + pre_contract_dims[0] = out_height * out_width; + for (int i = 0; i < NumDims - 3; ++i) { + pre_contract_dims[0] *= in.dimension(i); + } + } + + // Molds the output of the contraction into the shape expected by the used + // (assuming this is ColMajor): + // - 1st dim: kernel filters + // - 2nd dim: output height + // - 3rd dim: output width + // - 4th dim and beyond: everything else including batch size + DSizes post_contract_dims; + if (isColMajor) { + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = out_height; + post_contract_dims[2] = out_width; + for (int i = 3; i < NumDims; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } else { + post_contract_dims[NumDims - 1] = kernelFilters; + post_contract_dims[NumDims - 2] = out_height; + post_contract_dims[NumDims - 3] = out_width; + for (int i = 0; i < NumDims - 3; ++i) { + post_contract_dims[i] = in.dimension(i); + } + } + + DSizes kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters; + kernel_dims[1] = kernelChannels * kernelRows * kernelCols; + } else { + kernel_dims[0] = kernelChannels * kernelRows * kernelCols; + kernel_dims[1] = kernelFilters; + } + // TODO(yangke): choose() is defined in TensorContraction.h -- consider + // moving it to somewhere more "common". + return + input + .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims); +} + +} // end namespace Eigen + +// clang-format on + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h new file mode 100644 index 0000000000000000000000000000000000000000..7f78f69360b1ebbfb08600c8bc427f1ba9d5244d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -0,0 +1,143 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ + +#define EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_THREADS + +// clang-format off + +#include + +#include +#include +#include +#include +#include +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include // NOLINT(build/c++11) +#include + +#ifdef _WIN32 +#include +#elif defined(__APPLE__) +#include +#else +#include +#endif + + +// Because some programs may link Eigen in through other frameworks with +// different flags, we can run into multiple definition issues if we don't have +// a private namespace for our versions. This is a nasty hack, but a similar +// approach is used elsewhere to handle the problem, so it should be stable. +#define Eigen EigenForTFLite + +#include "Eigen/src/Core/util/StaticAssert.h" +#include "unsupported/Eigen/CXX11/Core" +#include "unsupported/Eigen/SpecialFunctions" + +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#include "Eigen/Core" + +// Beware: the order of the include matters to some compilers. For example +// TensorIndexList.h should be included before TensorDimensions.h in order to +// use index lists to encode tensor dimensions when compiling with llvm. +// We're defining this ourselves rather than using the Eigen Tensor header file +// so that we can alter the macro definition of TENSOR_CONTRACTION_DISPATCH to +// reduce binary size. +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/ThreadPoolInterface.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorNonBlockingThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStats.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMappers.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" +#undef TENSOR_CONTRACTION_DISPATCH +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous && \ + this->m_rhs_inner_dim_contiguous && \ + !this->m_rhs_inner_dim_reordered) { \ + METHOD ARGS; \ + } else { \ + eigen_assert(false && "Unsupported contraction formats"); \ + } + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/Tensor.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h new file mode 100644 index 0000000000000000000000000000000000000000..1d5c316194df0b87ee7eecbdd04bd5ce9e2e40b5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h @@ -0,0 +1,167 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is essentially unsupported/CXX11/Eigen/Tensor.h +// TODO(petewarden) - move this to a common location in Eigen itself. + +// clang-format off + + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ + + +#include "Eigen/Core" + +#if defined(EIGEN_USE_SYCL) +#undef min +#undef max +#undef isnan +#undef isinf +#undef isfinite +#include +#include +#include +#include +#include +#endif +#include +#include +#include + + + + + +#ifdef _WIN32 +typedef __int16 int16_t; +typedef unsigned __int16 uint16_t; +typedef __int32 int32_t; +typedef unsigned __int32 uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#include +#else +#include +#include +#endif + +#if __cplusplus > 199711 || EIGEN_COMP_MSVC >= 1900 +#include +#endif + +#ifdef _WIN32 +#include +#elif defined(__APPLE__) +#include +#else +#include +#endif + +// #if defined(EIGEN_USE_LIBXSMM) +// #include "libxsmm.h" +// #endif + +#ifdef EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/ThreadPool" +#endif + + +#include "Eigen/src/Core/util/DisableStupidWarnings.h" + +#include "unsupported/Eigen/SpecialFunctions" +#include "unsupported/Eigen/CXX11/src/util/CXX11Meta.h" +#include "unsupported/Eigen/CXX11/src/util/MaxSizeVector.h" + + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h" + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorGlobalFunctions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBase.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReductionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h" + +#undef TENSOR_CONTRACTION_DISPATCH +#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS) \ + if (this->m_lhs_inner_dim_contiguous && \ + this->m_rhs_inner_dim_contiguous && \ + !this->m_rhs_inner_dim_reordered) { \ + METHOD ARGS; \ + } else { \ + eigen_assert(false && "Unsupported contraction formats"); \ + } + + +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorCustomOp.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorScan.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h" +#include "unsupported/Eigen/CXX11/src/Tensor/Tensor.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" + +#include "Eigen/src/Core/util/ReenableStupidWarnings.h" + + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..b3615f4658a1a70284cc9d386a868a87aa09819b --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -0,0 +1,195 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace multithreaded_ops { + +class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { + public: + explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} + ~EigenThreadPoolWrapper() override {} + + void Schedule(std::function fn) override { + pool_->Schedule(std::move(fn)); + } + int NumThreads() const override { return pool_->NumThreads(); } + int CurrentThreadId() const override { return pool_->CurrentThreadId(); } + + private: + Eigen::ThreadPool* pool_ = nullptr; +}; + +// We have a single global threadpool for all convolution operations. This means +// that inferences started from different threads may block each other, but +// since the underlying resource of CPU cores should be consumed by the +// operations anyway, it shouldn't affect overall performance. +const Eigen::ThreadPoolDevice& GetThreadPoolDevice() { + const int thread_count = 4; + static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count); + static EigenThreadPoolWrapper* thread_pool_wrapper = + new EigenThreadPoolWrapper(tp); + static Eigen::ThreadPoolDevice* device = + new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count); + return *device; +} + +// Shorthands for the types we need when interfacing with the EigenTensor +// library. +typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + EigenMatrix; +typedef Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + ConstEigenMatrix; + +typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + EigenTensor; +typedef Eigen::TensorMap< + Eigen::Tensor, + Eigen::Aligned> + ConstEigenTensor; + +// Utility functions we need for the EigenTensor API. +template +struct MatMulConvFunctor { + // Computes on device "d": out = in0 * in1, where * is matrix + // multiplication. + void operator()( + const Device& d, EigenMatrix out, ConstEigenMatrix in0, + ConstEigenMatrix in1, + const Eigen::array, 1>& dim_pair) { + out.device(d) = in0.contract(in1, dim_pair); + } +}; + +template +class EigenTensorConvFunctor { + private: + Eigen::PaddingType TfLitePadding2EigenPadding(TfLitePadding padding) { + switch (padding) { + case kTfLitePaddingValid: + return Eigen::PADDING_VALID; + case kTfLitePaddingSame: + return Eigen::PADDING_SAME; + case kTfLitePaddingUnknown: + assert(false); // should never get here. + return Eigen::PADDING_VALID; + } + return Eigen::PADDING_SAME; // Prevent compiler warning about missing + // return + } + + public: + void operator()(const T* input_data, T* im2col_buffer, int input_batches, + int input_height, int input_width, int input_depth, + const T* filter_data, int filter_height, int filter_width, + int filter_count, int stride_rows, int stride_cols, + int pad_width, int pad_height, TfLitePadding padding, + T* output_data, int output_height, int output_width) { + const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice(); + + const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 && + stride_rows == 1 && stride_cols == 1); + if (is_1x1_kernel) { + // For 1x1 kernel, the 2D convolution is reduced to matrix + // multiplication. + const int conv_width = output_height * output_width; + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + EigenMatrix output(output_data, conv_width, filter_count); + ConstEigenMatrix input(input_data, conv_width, input_depth); + ConstEigenMatrix filter(filter_data, input_depth, filter_count); + MatMulConvFunctor()(device, output, input, + filter, dim_pair); + } else if (filter_height == input_height && filter_width == input_width && + pad_width == 0 && pad_height == 0) { + // If the input data and filter have the same height/width, + // the 2D convolution is reduced to matrix multiplication. + const int k = // Length of reduction dimension. + filter_width * filter_height * input_depth; + Eigen::array, 1> dim_pair; + dim_pair[0] = Eigen::IndexPair(1, 0); + EigenMatrix output(output_data, 1, filter_count); + ConstEigenMatrix input(input_data, 1, k); + ConstEigenMatrix filter(filter_data, k, filter_count); + MatMulConvFunctor()(device, output, input, + filter, dim_pair); + } else { + EigenTensor output(output_data, input_batches, output_height, + output_width, filter_count); + ConstEigenTensor input(input_data, input_batches, input_height, + input_width, input_depth); + ConstEigenTensor filter(filter_data, filter_height, filter_width, + input_depth, filter_count); + output.device(device) = + Eigen::SpatialConvolution(input, filter, stride_cols, stride_rows, + TfLitePadding2EigenPadding(padding)); + } + } +}; + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, TfLitePadding padding, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + float* im2col_data, const Dims<4>& im2col_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + EigenTensorConvFunctor conv_functor; + conv_functor(input_data, im2col_data, batches, input_height, input_width, + input_depth, filter_data, filter_height, filter_width, + output_depth, stride_height, stride_width, pad_height, pad_width, + padding, output_data, output_height, output_width); + + optimized_ops::AddBiasAndEvalActivationFunction( + bias_data, bias_dims, output_data, output_dims, output_activation_min, + output_activation_max); +} + +} // namespace multithreaded_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf0bdfb1fb875c4b54c55e25d4a17541507ecd4c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -0,0 +1,337 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" + +#ifdef USE_NEON + +#include +#define kFloatWeightsPerNeonLane 4 + +namespace tflite { +namespace tensor_utils { + +void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + m_cols - (m_cols & (kFloatWeightsPerNeonLane - 1)); + + // The arrays used to cache the vector. + float32x4_t* vector_cache_float32x4 = + new float32x4_t[(m_cols / kFloatWeightsPerNeonLane) * + sizeof(float32x4_t)]; + const int kUnrollSize = 2; + for (int b = 0; b < n_batch; b++) { + float* result_in_batch = result + b * m_rows * result_stride; + const float* vector_in_batch = vector + b * m_cols; + + const float* matrix_ptr0 = matrix; + // If there is only 1 row, we don't want to assign an illegal pointer. + const float* matrix_ptr1 = nullptr; + if (m_rows > 1) { + matrix_ptr1 = matrix + m_cols; + } + + // Cahce the vector. + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + vector_cache_float32x4[c >> 2] = vld1q_f32(vector_in_batch + c); + } + + // Main matrix by vector multiplication loop, which handles two rows of + // matrix by vector multiplication. + for (int r = 0; r < (m_rows & ~(kUnrollSize - 1)); r += kUnrollSize) { + float32x4_t acc0_32x4 = vmovq_n_f32(0.0); + float32x4_t acc1_32x4 = vmovq_n_f32(0.0); + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + float32x4_t temp = vector_cache_float32x4[c >> 2]; + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); + float32x4_t v1_f32x4 = vld1q_f32(matrix_ptr1 + c); + // Vector multiply-accumulate 4 float + acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); + acc1_32x4 = vmlaq_f32(acc1_32x4, v1_f32x4, temp); + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this column. + *result_in_batch += + (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + + vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); + *(result_in_batch + result_stride) += + (vgetq_lane_f32(acc1_32x4, 0) + vgetq_lane_f32(acc1_32x4, 1) + + vgetq_lane_f32(acc1_32x4, 2) + vgetq_lane_f32(acc1_32x4, 3)); + for (int c = postamble_start; c < m_cols; c++) { + *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; + *(result_in_batch + result_stride) += + matrix_ptr1[c] * vector_in_batch[c]; + } + matrix_ptr0 += kUnrollSize * m_cols; + matrix_ptr1 += kUnrollSize * m_cols; + result_in_batch += kUnrollSize * result_stride; + } + for (int r = (m_rows & ~(kUnrollSize - 1)); r < m_rows; r++) { + float32x4_t acc0_32x4 = vmovq_n_f32(0.0); + for (int c = 0; c < postamble_start; c += kFloatWeightsPerNeonLane) { + float32x4_t temp = vector_cache_float32x4[c >> 2]; + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v0_f32x4 = vld1q_f32(matrix_ptr0 + c); + // Vector multiply-accumulate 4 float + acc0_32x4 = vmlaq_f32(acc0_32x4, v0_f32x4, temp); + } + // Add the 4 intermediate sum values to get the final dot-prod value for + // this column. + *result_in_batch += + (vgetq_lane_f32(acc0_32x4, 0) + vgetq_lane_f32(acc0_32x4, 1) + + vgetq_lane_f32(acc0_32x4, 2) + vgetq_lane_f32(acc0_32x4, 3)); + for (int c = postamble_start; c < m_cols; c++) { + *result_in_batch += matrix_ptr0[c] * vector_in_batch[c]; + } + matrix_ptr0 += m_cols; + result_in_batch += result_stride; + } + } + delete[] vector_cache_float32x4; +} + +void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + // Vector multiply 4 float + float32x4_t mul_32x4 = vmulq_f32(v1_f32x4, v2_f32x4); + // Save to result array. + vst1q_f32(&result[v], mul_32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] = vector1[v] * vector2[v]; + } +} + +void NeonVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + float32x4_t acc_32x4 = vld1q_f32(result + v); + // Vector multiply-accumulate 4 float + acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4); + // Save to result array. + vst1q_f32(&result[v], acc_32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] += vector1[v] * vector2[v]; + } +} + +void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + // The arrays used to cache the vector. + float32x4_t* vector_cache_float32x4 = + new float32x4_t[(v_size / kFloatWeightsPerNeonLane) * + sizeof(float32x4_t)]; + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + vector_cache_float32x4[v >> 2] = vld1q_f32(vector + v); + } + + float* result_ptr = result; + const float* batch_vector_ptr = batch_vector; + for (int b = 0; b < n_batch; b++) { + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load from memory to vectors. + float32x4_t result_f32x4 = vld1q_f32(result_ptr + v); + float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector_ptr + v); + // Multiply-accumulate. + result_f32x4 = vmlaq_f32(result_f32x4, batch_vector_f32x4, + vector_cache_float32x4[v >> 2]); + // Store. + vst1q_f32(result_ptr + v, result_f32x4); + } + // Postamble loop + for (int v = postamble_start; v < v_size; v++) { + result_ptr[v] += vector[v] * batch_vector_ptr[v]; + } + // Update the pointers. + result_ptr += v_size; + batch_vector_ptr += v_size; + } + delete[] vector_cache_float32x4; +} + +void NeonSub1Vector(const float* vector, int v_size, float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + float32x4_t one_f32x4 = vmovq_n_f32(1.0); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from the current pointers of the input column and + // subtract from 1. + float32x4_t v_f32x4 = vld1q_f32(vector + v); + float32x4_t result_f32x4 = vsubq_f32(one_f32x4, v_f32x4); + // Save to output. + vst1q_f32(result + v, result_f32x4); + } + for (int v = postamble_start; v < v_size; v++) { + result[v] = 1.0f - vector[v]; + } +} + +void NeonClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + + // Replicate abs_limit and -abs_limit in two vectors. + const float32x4_t abs_limit_f32x4 = vmovq_n_f32(abs_limit); + const float32x4_t neg_abs_limit_f32x4 = vmovq_n_f32(-abs_limit); + + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load from memory to vector. + float32x4_t v_f32x4 = vld1q_f32(vector + v); + // Clip between abs_limit and -abs_limit. + float32x4_t result_f32x4 = vminq_f32(abs_limit_f32x4, v_f32x4); + result_f32x4 = vmaxq_f32(neg_abs_limit_f32x4, result_f32x4); + // Save to output. + vst1q_f32(result + v, result_f32x4); + } + // Postamble loop. + for (int v = postamble_start; v < v_size; v++) { + result[v] = (abs_limit < vector[v]) ? abs_limit : vector[v]; + result[v] = (-abs_limit > result[v]) ? -abs_limit : result[v]; + } +} + +float NeonVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main + // vectorized loop, and we need to process sequentially. postamble_start shows + // the start index where this should happen. + const int postamble_start = + v_size - (v_size & (kFloatWeightsPerNeonLane - 1)); + float32x4_t acc_32x4 = vmovq_n_f32(0.0); + for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) { + // Load 4 float values from vector1 and vector2 and accumulator. + float32x4_t v1_f32x4 = vld1q_f32(vector1 + v); + float32x4_t v2_f32x4 = vld1q_f32(vector2 + v); + // Vector multiply-accumulate 4 float + acc_32x4 = vmlaq_f32(acc_32x4, v1_f32x4, v2_f32x4); + } + + float result = (vgetq_lane_f32(acc_32x4, 0) + vgetq_lane_f32(acc_32x4, 1) + + vgetq_lane_f32(acc_32x4, 2) + vgetq_lane_f32(acc_32x4, 3)); + // Postamble loop. + for (int v = postamble_start; v < v_size; v++) { + result += vector1[v] * vector2[v]; + } + return result; +} + +void NeonBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + float* result_ptr = result; + const float* vector1_ptr = vector1; + const float* vector2_ptr = vector2; + for (int b = 0; b < n_batch; b++) { + *result_ptr = NeonVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size); + vector1_ptr += v_size; + vector2_ptr += v_size; + result_ptr += result_stride; + } +} + +void NeonReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + const float* input_vector_ptr = input_vector; + for (int o = 0; o < output_size; o++) { + // If reduction_size is not divisible by kWeightsPerNeonLane, we cannot use + // the main vectorized loop, and we need to process sequentially. + // postamble_start shows the start index where this should happen. + const int postamble_start = + reduction_size - (reduction_size & (kFloatWeightsPerNeonLane - 1)); + float32x4_t sum_f32x4 = vmovq_n_f32(0.0); + for (int r = 0; r < postamble_start; r += kFloatWeightsPerNeonLane) { + float32x4_t v1_f32x4 = vld1q_f32(input_vector_ptr + r); + sum_f32x4 = vaddq_f32(sum_f32x4, v1_f32x4); + } + output_vector[o] += + (vgetq_lane_f32(sum_f32x4, 0) + vgetq_lane_f32(sum_f32x4, 1) + + vgetq_lane_f32(sum_f32x4, 2) + vgetq_lane_f32(sum_f32x4, 3)); + input_vector_ptr += postamble_start; + + // Postamble loop. + for (int r = postamble_start; r < reduction_size; r++) { + output_vector[o] += *input_vector_ptr++; + } + } +} + +void NeonVectorShiftLeft(float* vector, int v_size, float shift_value) { + // This variable keeps track of the next to the last index which is being + // copied to make sure we are not out of the vector boundary. + int last_index_copy = kFloatWeightsPerNeonLane; + int current_index_copy = 0; + while (last_index_copy < v_size) { + float32x4_t v_f32x4 = vld1q_f32(vector + current_index_copy + 1); + vst1q_f32(vector + current_index_copy, v_f32x4); + current_index_copy += kFloatWeightsPerNeonLane; + last_index_copy += kFloatWeightsPerNeonLane; + } + // Postamble loop. + for (int i = current_index_copy; i < v_size - 1; i++) { + vector[i] = vector[i + 1]; + } + vector[v_size - 1] = shift_value; +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3a4af87304eaf33489b38bd9b15ad9789e091d24 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ + +// TODO(ghodrat): Remove this header file and the dependency to internal data +// structure. +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" + +namespace tflite { +namespace tensor_utils { + +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + NEON_OR_PORTABLE(MatrixBatchVectorMultiplyAccumulate, matrix, m_rows, m_cols, + vector, n_batch, result, result_stride); +} + +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + NEON_OR_PORTABLE(VectorVectorCwiseProduct, vector1, vector2, v_size, result); +} + +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + NEON_OR_PORTABLE(VectorVectorCwiseProductAccumulate, vector1, vector2, v_size, + result); +} + +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result) { + NEON_OR_PORTABLE(VectorBatchVectorCwiseProductAccumulate, vector, v_size, + batch_vector, n_batch, result); +} + +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + return NEON_OR_PORTABLE(VectorVectorDotProduct, vector1, vector2, v_size); +} + +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + NEON_OR_PORTABLE(BatchVectorBatchVectorDotProduct, vector1, vector2, v_size, + n_batch, result, result_stride); +} + +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); +} + +void ApplySigmoidToVector(const float* vector, int v_size, float* result) { + PortableApplySigmoidToVector(vector, v_size, result); +} + +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result) { + PortableApplyActivationToVector(vector, v_size, activation, result); +} + +void CopyVector(const float* vector, int v_size, float* result) { + PortableCopyVector(vector, v_size, result); +} + +void Sub1Vector(const float* vector, int v_size, float* result) { + NEON_OR_PORTABLE(Sub1Vector, vector, v_size, result); +} + +void ZeroVector(float* vector, int v_size) { + PortableZeroVector(vector, v_size); +} + +float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } + +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result); +} + +void VectorShiftLeft(float* vector, int v_size, float shift_value) { + NEON_OR_PORTABLE(VectorShiftLeft, vector, v_size, shift_value); +} + +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + NEON_OR_PORTABLE(ReductionSumVector, input_vector, output_vector, output_size, + reduction_size); +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..cd565c16a1ee7226f83c19f0020beed75e401497 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -0,0 +1,3715 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace optimized_ops { + +// Make a local VectorMap typedef allowing to map a float array +// as a Eigen vector expression. The std::conditional here is to +// construct the suitable Eigen type for the constness of the +// data. Indeed, for const data, we need to produce +// Eigen::Map> +// and not the more straightforward +// Eigen::Map> +template +using VectorMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, 1>>, + Eigen::Map>>::type; + +template +VectorMap MapAsVector(Scalar* data, const Dims& dims) { + const int size = RequiredBufferSizeForDims(dims); + return VectorMap(data, size, 1); +} + +// Make a local VectorMap typedef allowing to map a float array +// as a Eigen matrix expression. The same explanation as for VectorMap +// above also applies here. +template +using MatrixMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithLastDimAsCols(Scalar* data, + const Dims& dims) { + const int cols = dims.sizes[N - 1]; + int rows = 1; + for (int d = 0; d < N - 1; d++) { + rows *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +using ArrayMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +ArrayMap MapAsArrayWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return ArrayMap(data, rows, cols); +} + +// TODO(b/62193649): this function is only needed as long +// as we have the --variable_batch hack. +template +MatrixMap MapAsMatrixWithGivenNumberOfRows(Scalar* data, + const Dims& dims, + int rows) { + int cols = 1; + bool matched_rows = false; + for (int d = 0; d < N; d++) { + cols *= dims.sizes[d]; + if (cols == rows) { + matched_rows = true; + cols = 1; + } + } + TFLITE_DCHECK(matched_rows); + return MatrixMap(data, rows, cols); +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// ELEMENT-WISE BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +inline bool AreSameDims(const Dims<4>& dims1, const Dims<4>& dims2) { + for (int i = 0; i < 4; i++) { + if (dims1.sizes[i] != dims2.sizes[i]) { + return false; + } + } + return true; +} + +inline void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* array_data, + const Dims<4>& array_dims, + float output_activation_min, + float output_activation_max) { +#ifdef USE_NEON + gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); + const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; + const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + TFLITE_DCHECK_EQ((array_size % bias_size), 0); + float* array_ptr = array_data; + float* array_end_ptr = array_ptr + array_size; + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; array_ptr != array_end_ptr; array_ptr += bias_size) { + int i = 0; + for (; i <= bias_size - 16; i += 16) { + auto b0 = vld1q_f32(bias_data + i); + auto b1 = vld1q_f32(bias_data + i + 4); + auto b2 = vld1q_f32(bias_data + i + 8); + auto b3 = vld1q_f32(bias_data + i + 12); + auto a0 = vld1q_f32(array_ptr + i); + auto a1 = vld1q_f32(array_ptr + i + 4); + auto a2 = vld1q_f32(array_ptr + i + 8); + auto a3 = vld1q_f32(array_ptr + i + 12); + auto x0 = vaddq_f32(a0, b0); + auto x1 = vaddq_f32(a1, b1); + auto x2 = vaddq_f32(a2, b2); + auto x3 = vaddq_f32(a3, b3); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + vst1q_f32(array_ptr + i, x0); + vst1q_f32(array_ptr + i + 4, x1); + vst1q_f32(array_ptr + i + 8, x2); + vst1q_f32(array_ptr + i + 12, x3); + } + for (; i <= bias_size - 4; i += 4) { + auto b = vld1q_f32(bias_data + i); + auto a = vld1q_f32(array_ptr + i); + auto x = vaddq_f32(a, b); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + vst1q_f32(array_ptr + i, x); + } + for (; i < bias_size; i++) { + array_ptr[i] = ActivationFunctionWithMinMax(array_ptr[i] + bias_data[i], + output_activation_min, + output_activation_max); + } + } +#else // not NEON + gemmlowp::ScopedProfilingLabel label("AddBiasAndEvalActivationFunction"); + const int bias_size = bias_dims.sizes[3] * bias_dims.strides[3]; + const int array_size = array_dims.sizes[3] * array_dims.strides[3]; + TFLITE_DCHECK_EQ((array_size % bias_size), 0); + for (int array_offset = 0; array_offset < array_size; + array_offset += bias_size) { + for (int i = 0; i < bias_size; i++) { + array_data[array_offset + i] = ActivationFunctionWithMinMax( + array_data[array_offset + i] + bias_data[i], output_activation_min, + output_activation_max); + } + } +#endif +} + +// legacy, for compatibility with old checked-in code +template +void AddBiasAndEvalActivationFunction(const float* bias_data, + const Dims<4>& bias_dims, + float* array_data, + const Dims<4>& array_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + AddBiasAndEvalActivationFunction(bias_data, bias_dims, array_data, array_dims, + output_activation_min, + output_activation_max); +} + +template +void Gemm(const Eigen::MatrixBase& lhs, const Eigen::MatrixBase& rhs, + Eigen::MatrixBase* result) { + if (rhs.cols() == 1) { + gemmlowp::ScopedProfilingLabel label("GEMV"); + result->col(0).noalias() = lhs * rhs.col(0); + } else { + gemmlowp::ScopedProfilingLabel label("GEMM"); + result->noalias() = lhs * rhs; + } +} + +inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FullyConnected"); + // TODO(b/62193649): this convoluted shape computation (determining + // input_rows from the weights_dims, then MapAsMatrixWithGivenNumberOfRows) + // is because the current --variable_batch hack consists in overwriting the + // 3rd dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + // When that is fixed, this should become: + // const auto input_matrix_map = + // MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const int input_rows = ArraySize(weights_dims, 0); + const auto input_matrix_map = + MapAsMatrixWithGivenNumberOfRows(input_data, input_dims, input_rows); + const auto filter_matrix_map = + MapAsMatrixWithFirstDimAsRows(weights_data, weights_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, const Dims<4>& weights_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data, + bias_dims, output_activation_min, output_activation_max, + output_data, output_dims); +} + +inline void preload_l1_stream(const uint8* ptr) { +#ifdef GEMMLOWP_ARM_64 + asm volatile("prfm pldl1strm, [%[ptr]]\n" ::[ptr] "r"(ptr) :); +#else + gemmlowp::Prefetch(ptr); +#endif +} + +#ifdef USE_NEON +inline void FullyConnectedAsGEMV( + const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, + const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int32 output_offset, + int32 output_multiplier, int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FullyConnectedAsGEMV/8bit"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3), + 1); + const int input_size = input_dims.strides[3]; + const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); + static constexpr int kPeel = 4; + for (int k = 0; k < input_size; k += 64) { + preload_l1_stream(input_data + k); + } + for (int k = 0; k < kPeel * input_size; k += 64) { + preload_l1_stream(filter_data + k); + } + TFLITE_DCHECK(!(output_size % kPeel)); + const int32* bias_ptr = bias_data; + uint8* output_ptr = output_data; + for (int out = 0; out < output_size; out += kPeel) { + int32x4_t acc[kPeel]; + for (int k = 0; k < kPeel; k++) { + acc[k] = vdupq_n_s32(0); + } + const int16x8_t input_offset_vec = vdupq_n_s16(input_offset); + const int16x8_t filter_offset_vec = vdupq_n_s16(filter_offset); + int in = 0; + for (; in <= input_size - 16; in += 16) { + const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); + uint8x16_t filter_val_u8[kPeel]; + for (int k = 0; k < kPeel; k++) { + const uint8* filter_ptr = filter_data + in + (out + k) * input_size; + filter_val_u8[k] = vld1q_u8(filter_ptr); + preload_l1_stream(filter_ptr + 64); + } + int16x8_t input_val[2]; + const uint8x8_t low = vget_low_u8(input_val_u8); + const uint8x8_t high = vget_high_u8(input_val_u8); + input_val[0] = vreinterpretq_s16_u16(vmovl_u8(low)); + input_val[1] = vreinterpretq_s16_u16(vmovl_u8(high)); + input_val[0] = vaddq_s16(input_val[0], input_offset_vec); + input_val[1] = vaddq_s16(input_val[1], input_offset_vec); + int16x8_t filter_val[kPeel][2]; + for (int k = 0; k < kPeel; k++) { + const uint8x8_t low = vget_low_u8(filter_val_u8[k]); + const uint8x8_t high = vget_high_u8(filter_val_u8[k]); + filter_val[k][0] = vreinterpretq_s16_u16(vmovl_u8(low)); + filter_val[k][1] = vreinterpretq_s16_u16(vmovl_u8(high)); + filter_val[k][0] = vaddq_s16(filter_val[k][0], filter_offset_vec); + filter_val[k][1] = vaddq_s16(filter_val[k][1], filter_offset_vec); + } + for (int p = 0; p < 2; p++) { + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k][p]), + vget_low_s16(input_val[p])); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k][p]), + vget_high_s16(input_val[p])); + } + } + } + for (; in <= input_size - 8; in += 8) { + const uint8x8_t input_val_u8 = vld1_u8(input_data + in); + uint8x8_t filter_val_u8[kPeel]; + for (int k = 0; k < kPeel; k++) { + const uint8* filter_ptr = filter_data + in + (out + k) * input_size; + filter_val_u8[k] = vld1_u8(filter_ptr); + } + int16x8_t input_val; + input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); + input_val = vaddq_s16(input_val, input_offset_vec); + int16x8_t filter_val[kPeel]; + for (int k = 0; k < kPeel; k++) { + filter_val[k] = vreinterpretq_s16_u16(vmovl_u8(filter_val_u8[k])); + filter_val[k] = vaddq_s16(filter_val[k], filter_offset_vec); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_low_s16(filter_val[k]), + vget_low_s16(input_val)); + } + for (int k = 0; k < kPeel; k++) { + acc[k] = vmlal_s16(acc[k], vget_high_s16(filter_val[k]), + vget_high_s16(input_val)); + } + } + if (in < input_size) { + int32 buf[4 * kPeel]; + for (int k = 0; k < 4; k++) { + vst1q_s32(buf + 4 * k, acc[k]); + } + for (; in < input_size; in++) { + int lane = (in + 8 - input_size) % 4; + const int32 input_val = input_data[in] + input_offset; + for (int k = 0; k < kPeel; k++) { + int32 filter_val = + filter_data[in + (out + k) * input_size] + filter_offset; + buf[lane + 4 * k] += filter_val * input_val; + } + } + for (int k = 0; k < 4; k++) { + acc[k] = vld1q_s32(buf + 4 * k); + } + } + + // Horizontally reduce accumulators + int32x2_t pairwise_reduced_acc[kPeel]; + for (int k = 0; k < kPeel; k++) { + pairwise_reduced_acc[k] = + vpadd_s32(vget_low_s32(acc[k]), vget_high_s32(acc[k])); + } + static_assert(kPeel == 4, "the code below currently assumes kPeel = 4"); + const int32x2_t reduced_lo = + vpadd_s32(pairwise_reduced_acc[0], pairwise_reduced_acc[1]); + const int32x2_t reduced_hi = + vpadd_s32(pairwise_reduced_acc[2], pairwise_reduced_acc[3]); + int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); + // Add bias values. + int32x4_t bias_vec = vld1q_s32(bias_ptr); + bias_ptr += 4; + reduced = vaddq_s32(reduced, bias_vec); + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, output_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, output_shift); + // Add the output offset. + const int32x4_t output_offset_vec = vdupq_n_s32(output_offset); + reduced = vaddq_s32(reduced, output_offset_vec); + // Narrow values down to 16 bit signed. + const int16x4_t res16 = vqmovn_s32(reduced); + // Narrow values down to 8 bit unsigned, saturating. + uint8x8_t res8 = vqmovun_s16(vcombine_s16(res16, res16)); + // Apply the clamping from the activation function + res8 = vmax_u8(res8, vdup_n_u8(output_activation_min)); + res8 = vmin_u8(res8, vdup_n_u8(output_activation_max)); + // Store results to destination. Assumes 32bit alignment. + vst1_lane_u32(reinterpret_cast(output_ptr), + vreinterpret_u32_u8(res8), 0); + output_ptr += kPeel; + } +} +#endif // USE_NEON + +struct GemmlowpOutputPipeline { + typedef gemmlowp::VectorMap + ColVectorMap; + typedef std::tuple< + gemmlowp::OutputStageBiasAddition, + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint, + gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> + Pipeline; + static Pipeline Make(const int32* bias_data, int output_rows, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max) { + ColVectorMap bias_vector(bias_data, output_rows); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageQuantizeDownInt32ToUint8ScaleByFixedPoint + quantize_down_stage; + quantize_down_stage.result_offset_after_shift = output_offset; + quantize_down_stage.result_fixedpoint_multiplier = output_multiplier; + quantize_down_stage.result_shift = output_shift; + gemmlowp::OutputStageClamp clamp_stage; + clamp_stage.min = output_activation_min; + clamp_stage.max = output_activation_max; + gemmlowp::OutputStageSaturatingCastToUint8 saturating_cast_stage; + return std::make_tuple(bias_addition_stage, quantize_down_stage, + clamp_stage, saturating_cast_stage); + } +}; + +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("FullyConnected/8bit"); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); +#ifdef USE_NEON + const int output_size = MatchingArraySize(filter_dims, 1, output_dims, 0); + if (batches == 1 && !(output_size % 4)) { + return FullyConnectedAsGEMV( + input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, output_data, + output_dims); + } +#endif // USE_NEON + const int filter_rows = filter_dims.sizes[1]; + const int filter_cols = filter_dims.sizes[0]; + TFLITE_DCHECK_EQ(filter_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(filter_dims.sizes[3], 1); + const int output_rows = output_dims.sizes[0]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, batches, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, batches, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims, gemm_context); +} + +template +inline void ExtractPatchIntoBufferColumn( + const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth, + int stride_width, int stride_height, int pad_width, int pad_height, + int in_width, int in_height, int in_depth, int single_buffer_length, + int buffer_id, const T* in_data, T* conv_buffer_data, uint8 byte_zero) { + gemmlowp::ScopedProfilingLabel label("ExtractPatchIntoBufferColumn"); + // This chunk of code reshapes all the inputs corresponding to + // output (b, h, w) to a column vector in conv_buffer(:, buffer_id). + const int kwidth_times_indepth = kwidth * in_depth; + const int inwidth_times_indepth = in_width * in_depth; + const int ih_ungated_start = h * stride_height - pad_height; + const int ih_ungated_end = (ih_ungated_start + kheight); + const int ih_end = std::min(ih_ungated_end, in_height); + const int iw_ungated_start = w * stride_width - pad_width; + const int iw_ungated_end = (iw_ungated_start + kwidth); + const int iw_end = std::min(iw_ungated_end, in_width); + // If the patch is off the edge of the input image, skip writing those rows + // and columns from the patch into the output array. + const int h_offset = std::max(0, -ih_ungated_start); + const int w_offset = std::max(0, -iw_ungated_start); + const int ih_start = std::max(0, ih_ungated_start); + const int iw_start = std::max(0, iw_ungated_start); + const int single_row_num = + std::min(kwidth - w_offset, in_width - iw_start) * in_depth; + const int output_row_offset = (buffer_id * single_buffer_length); + int out_offset = + output_row_offset + (h_offset * kwidth + w_offset) * in_depth; + int in_offset = Offset(input_dims, 0, iw_start, ih_start, b); + + // Express all of the calculations as padding around the input patch. + const int top_padding = h_offset; + const int bottom_padding = (ih_ungated_end - ih_end); + const int left_padding = w_offset; + const int right_padding = (iw_ungated_end - iw_end); + assert(single_row_num == + ((kwidth - (left_padding + right_padding)) * in_depth)); + + // Write out zeroes to the elements representing the top rows of the input + // patch that are off the edge of the input image. + if (top_padding > 0) { + const int top_row_elements = (top_padding * kwidth * in_depth); + memset(conv_buffer_data + output_row_offset, byte_zero, + (top_row_elements * sizeof(T))); + } + + // If the patch is on the interior of the input image horizontally, just copy + // over the rows sequentially, otherwise add zero padding at the start or end. + if ((left_padding == 0) && (right_padding == 0)) { + for (int ih = ih_start; ih < ih_end; ++ih) { + memcpy(conv_buffer_data + out_offset, in_data + in_offset, + single_row_num * sizeof(T)); + out_offset += kwidth_times_indepth; + in_offset += inwidth_times_indepth; + } + } else { + for (int ih = ih_start; ih < ih_end; ++ih) { + if (left_padding > 0) { + const int left_start = (out_offset - (left_padding * in_depth)); + memset(conv_buffer_data + left_start, byte_zero, + (left_padding * in_depth * sizeof(T))); + } + memcpy(conv_buffer_data + out_offset, in_data + in_offset, + single_row_num * sizeof(T)); + if (right_padding > 0) { + const int right_start = (out_offset + single_row_num); + memset(conv_buffer_data + right_start, byte_zero, + (right_padding * in_depth * sizeof(T))); + } + out_offset += kwidth_times_indepth; + in_offset += inwidth_times_indepth; + } + } + + // If the bottom of the patch falls off the input image, pad the values + // representing those input rows with zeroes. + if (bottom_padding > 0) { + const int bottom_row_elements = (bottom_padding * kwidth * in_depth); + const int bottom_start = + output_row_offset + + ((top_padding + (ih_end - ih_start)) * kwidth * in_depth); + memset(conv_buffer_data + bottom_start, byte_zero, + (bottom_row_elements * sizeof(T))); + } +} + +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, + int stride_height, int pad_width, int pad_height, int kheight, + int kwidth, uint8 byte_zero, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Im2col"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + + int buffer_id = 0; + // Loop over the output nodes. + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < output_height; ++h) { + for (int w = 0; w < output_width; ++w) { + ExtractPatchIntoBufferColumn( + input_dims, w, h, b, kheight, kwidth, stride_width, stride_height, + pad_width, pad_height, input_width, input_height, input_depth, + output_depth, buffer_id, input_data, output_data, byte_zero); + ++buffer_id; + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int kheight, int kwidth, + uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, + kwidth, byte_zero, output_data, output_dims); +} + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + (void)im2col_data; + (void)im2col_dims; + gemmlowp::ScopedProfilingLabel label("Conv"); + + const float* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, 0, im2col_data, + im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + // TODO(aselle): We need to make sure to not send im2col if it is not + // needed. + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + const auto im2col_matrix_map = + MapAsMatrixWithFirstDimAsRows(gemm_input_data, *gemm_input_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), im2col_matrix_map, &output_matrix_map); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims, output_activation_min, + output_activation_max); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, pad_width, pad_height, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, output_data, + output_dims, im2col_data, im2col_dims); +} + +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("Conv/8bit"); + + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + const uint8* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + const int input_zero_point = -input_offset; + TFLITE_DCHECK_GE(input_zero_point, 0); + TFLITE_DCHECK_LE(input_zero_point, 255); + Im2col(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_height, filter_width, input_zero_point, + im2col_data, im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + const int gemm_input_rows = gemm_input_dims->sizes[0]; + const int gemm_input_cols = gemm_input_dims->sizes[1] * + gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int output_rows = output_dims.sizes[0]; + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, gemm_input_cols); + TFLITE_DCHECK_EQ(filter_cols, gemm_input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + gemmlowp::MatrixMap filter_matrix( + filter_data, filter_rows, filter_cols); + gemmlowp::MatrixMap input_matrix( + gemm_input_data, gemm_input_rows, gemm_input_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +// legacy, for compatibility with old checked-in code +template +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, stride, pad_width, + pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("DepthToSpace"); + + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + + const int output_depth = ArraySize(output_dims, 0); + const int batch_size = ArraySize(output_dims, 3); + + // Number of continuous values that we can copy in one interation. + const int stride = block_size * output_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch); + for (int offset_h = 0; offset_h < block_size; ++offset_h) { + const T* src = input_ptr; + for (int in_w = 0; in_w < input_width; ++in_w) { + memcpy(output_data, src, stride * sizeof(T)); + output_data += stride; + src += input_depth; + } + input_ptr += stride; + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int kheight, int kwidth, + uint8 byte_zero, T* output_data, const Dims<4>& output_dims) { + Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, + kwidth, byte_zero, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void ConvAsGemm(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ConvAsGemm"); + + const auto input_matrix_map = + MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + const auto filter_matrix_map = + MapAsMatrixWithLastDimAsCols(filter_data, filter_dims); + auto output_matrix_map = + MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + Gemm(filter_matrix_map.transpose(), input_matrix_map, &output_matrix_map); + + AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data, + output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + gemmlowp::ScopedProfilingLabel label("ConvAsGemm/8bit"); + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + const int input_rows = input_dims.sizes[0]; + const int input_cols = + input_dims.sizes[1] * input_dims.sizes[2] * input_dims.sizes[3]; + const int filter_rows = filter_dims.sizes[3]; + const int filter_cols = + filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2]; + const int output_rows = output_dims.sizes[0]; + const int output_cols = + output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3]; + TFLITE_DCHECK_EQ(output_rows, filter_rows); + TFLITE_DCHECK_EQ(output_cols, input_cols); + TFLITE_DCHECK_EQ(filter_cols, input_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_rows); + TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); + TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); + gemmlowp::MatrixMap filter_matrix( + filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( + input_data, filter_cols, output_cols, filter_cols); + gemmlowp::MatrixMap output_matrix( + output_data, output_rows, output_cols, output_rows); + const auto& output_pipeline = GemmlowpOutputPipeline::Make( + bias_data, output_rows, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max); + gemmlowp::GemmWithOutputPipeline( + gemm_context, filter_matrix, input_matrix, &output_matrix, filter_offset, + input_offset, output_pipeline); +} + +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("SpaceToDepth"); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + + const int input_depth = ArraySize(input_dims, 0); + const int batch_size = ArraySize(input_dims, 3); + + // Number of continuous values that we can copy in one interation. + const int stride = block_size * input_depth; + + for (int batch = 0; batch < batch_size; ++batch) { + for (int out_h = 0; out_h < output_height; ++out_h) { + T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch); + for (int offset_h = 0; offset_h < block_size; ++offset_h) { + T* dst = output_ptr; + for (int out_w = 0; out_w < output_width; ++out_w) { + memcpy(dst, input_data, stride * sizeof(T)); + input_data += stride; + dst += output_depth; + } + output_ptr += stride; + } + } + } +} + +template +void NonGlobalBatchNormalization( + const float* input_data, const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, const float* multiplier_data, + const Dims<4>& multiplier_dims, const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, + offset_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, + offset_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, x, y, 0)]) * + multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + + offset_data[Offset(offset_dims, c, x, y, 0)]); + } + } + } + } +} + +template +void GlobalBatchNormalization(const float* input_data, + const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, + const float* multiplier_data, + const Dims<4>& multiplier_dims, + const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, 0, 0, 0)]) * + multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + + offset_data[Offset(offset_dims, c, 0, 0, 0)]); + } + } + } + } +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu (not fused)"); + + const auto input = MapAsVector(input_data, input_dims); + auto output = MapAsVector(output_data, output_dims); + output = input.cwiseMax(0.0f); +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu1 (not fused)"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 1; + const float lower = -1; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Relu6 (not fused)"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 6; + const float lower = 0; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Normalization"); + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + float squared_l2_norm = 0; + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + squared_l2_norm += val * val; + } + float inverse_l2_norm = 1.0f / std::sqrt(squared_l2_norm); + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] * inverse_l2_norm; + } + } + } + } +} + +inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, + int* output_shift) { + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + TFLITE_DCHECK_GT(input, 0); + const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. + using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(batches, 1); + TFLITE_DCHECK_EQ(height, 1); + TFLITE_DCHECK_EQ(width, 1); + int32 square_l2_norm = 0; + for (int i = 0; i < depth; i++) { + int32 diff = input_data[i] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int i = 0; i < depth; i++) { + int32 diff = input_data[i] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[i] = static_cast(output_val); + } +} + +inline void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; +#ifdef USE_NEON + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + auto a10 = vld1q_f32(input1_data + i); + auto a11 = vld1q_f32(input1_data + i + 4); + auto a12 = vld1q_f32(input1_data + i + 8); + auto a13 = vld1q_f32(input1_data + i + 12); + auto a20 = vld1q_f32(input2_data + i); + auto a21 = vld1q_f32(input2_data + i + 4); + auto a22 = vld1q_f32(input2_data + i + 8); + auto a23 = vld1q_f32(input2_data + i + 12); + auto x0 = vaddq_f32(a10, a20); + auto x1 = vaddq_f32(a11, a21); + auto x2 = vaddq_f32(a12, a22); + auto x3 = vaddq_f32(a13, a23); + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + auto a1 = vld1q_f32(input1_data + i); + auto a2 = vld1q_f32(input2_data + i); + auto x = vaddq_f32(a1, a2); + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; i++) { + auto x = input1_data[i] + input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); + } +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + gemmlowp::ScopedProfilingLabel label("Add/8bit"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; + TFLITE_DCHECK_GT(input1_offset, -256); + TFLITE_DCHECK_GT(input2_offset, -256); + TFLITE_DCHECK_LT(input1_offset, 256); + TFLITE_DCHECK_LT(input2_offset, 256); +#ifdef USE_NEON + for (; i <= size - 8; i += 8) { + const auto input1_val_original = vld1_u8(input1_data + i); + const auto input2_val_original = vld1_u8(input2_data + i); + const auto input1_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input1_val_original)); + const auto input2_val_s16 = + vreinterpretq_s16_u16(vmovl_u8(input2_val_original)); + const auto input1_val = + vaddq_s16(input1_val_s16, vdupq_n_s16(input1_offset)); + const auto input2_val = + vaddq_s16(input2_val_s16, vdupq_n_s16(input2_offset)); + const auto input1_val_high = vget_high_s16(input1_val); + const auto input1_val_low = vget_low_s16(input1_val); + const auto input2_val_high = vget_high_s16(input2_val); + const auto input2_val_low = vget_low_s16(input2_val); + auto x11 = vmovl_s16(input1_val_low); + auto x12 = vmovl_s16(input1_val_high); + auto x21 = vmovl_s16(input2_val_low); + auto x22 = vmovl_s16(input2_val_high); + const auto left_shift_dup = vdupq_n_s32(left_shift); + x11 = vshlq_s32(x11, left_shift_dup); + x12 = vshlq_s32(x12, left_shift_dup); + x21 = vshlq_s32(x21, left_shift_dup); + x22 = vshlq_s32(x22, left_shift_dup); + x11 = vqrdmulhq_n_s32(x11, input1_multiplier); + x12 = vqrdmulhq_n_s32(x12, input1_multiplier); + x21 = vqrdmulhq_n_s32(x21, input2_multiplier); + x22 = vqrdmulhq_n_s32(x22, input2_multiplier); + const auto input1_shift_dup = vdupq_n_s32(-input1_shift); + const auto input2_shift_dup = vdupq_n_s32(-input2_shift); + x11 = vshlq_s32(x11, input1_shift_dup); + x12 = vshlq_s32(x12, input1_shift_dup); + x21 = vshlq_s32(x21, input2_shift_dup); + x22 = vshlq_s32(x22, input2_shift_dup); + auto s1 = vaddq_s32(x11, x21); + auto s2 = vaddq_s32(x12, x22); + s1 = vqrdmulhq_n_s32(s1, output_multiplier); + s2 = vqrdmulhq_n_s32(s2, output_multiplier); + using gemmlowp::RoundingDivideByPOT; + s1 = RoundingDivideByPOT(s1, output_shift); + s2 = RoundingDivideByPOT(s2, output_shift); + const auto s1_narrowed = vmovn_s32(s1); + const auto s2_narrowed = vmovn_s32(s2); + const auto s = vaddq_s16(vcombine_s16(s1_narrowed, s2_narrowed), + vdupq_n_s16(output_offset)); + vst1_u8(output_data + i, vqmovun_s16(s)); + } +#endif // NEON + + for (; i < size; i++) { + const int32 input1_val = input1_offset + input1_data[i]; + const int32 input2_val = input2_offset + input2_data[i]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = std::min( + output_activation_max, std::max(output_activation_min, raw_output)); + output_data[i] = static_cast(clamped_output); + } +} + +template +void Add(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Add/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() + input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() + scalar; + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar + input2_map.array(); + } else { + // Should not come here. + TFLITE_DCHECK(false); + } +} + +// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastAdd is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_dims, + input2_offset, input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul"); + /* const int batches = */ MatchingArraySize(input1_dims, 3, input2_dims, 3, + output_dims, 3); + /* const int height = */ MatchingArraySize(input1_dims, 2, input2_dims, 2, + output_dims, 2); + /* const int width = */ MatchingArraySize(input1_dims, 1, input2_dims, 1, + output_dims, 1); + /* const int depth = */ MatchingArraySize(input1_dims, 0, input2_dims, 0, + output_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + + int i = 0; + const int size = input1_dims.sizes[3] * input1_dims.strides[3]; +#ifdef USE_NEON + const auto activation_min = vdupq_n_f32(output_activation_min); + const auto activation_max = vdupq_n_f32(output_activation_max); + for (; i <= size - 16; i += 16) { + auto a10 = vld1q_f32(input1_data + i); + auto a11 = vld1q_f32(input1_data + i + 4); + auto a12 = vld1q_f32(input1_data + i + 8); + auto a13 = vld1q_f32(input1_data + i + 12); + auto a20 = vld1q_f32(input2_data + i); + auto a21 = vld1q_f32(input2_data + i + 4); + auto a22 = vld1q_f32(input2_data + i + 8); + auto a23 = vld1q_f32(input2_data + i + 12); + auto x0 = vmulq_f32(a10, a20); + auto x1 = vmulq_f32(a11, a21); + auto x2 = vmulq_f32(a12, a22); + auto x3 = vmulq_f32(a13, a23); + + x0 = vmaxq_f32(activation_min, x0); + x1 = vmaxq_f32(activation_min, x1); + x2 = vmaxq_f32(activation_min, x2); + x3 = vmaxq_f32(activation_min, x3); + x0 = vminq_f32(activation_max, x0); + x1 = vminq_f32(activation_max, x1); + x2 = vminq_f32(activation_max, x2); + x3 = vminq_f32(activation_max, x3); + + vst1q_f32(output_data + i, x0); + vst1q_f32(output_data + i + 4, x1); + vst1q_f32(output_data + i + 8, x2); + vst1q_f32(output_data + i + 12, x3); + } + for (; i <= size - 4; i += 4) { + auto a1 = vld1q_f32(input1_data + i); + auto a2 = vld1q_f32(input2_data + i); + auto x = vmulq_f32(a1, a2); + + x = vmaxq_f32(activation_min, x); + x = vminq_f32(activation_max, x); + + vst1q_f32(output_data + i, x); + } +#endif // NEON + + for (; i < size; i++) { + auto x = input1_data[i] * input2_data[i]; + output_data[i] = ActivationFunctionWithMinMax(x, output_activation_min, + output_activation_max); + } +} + +// legacy, for compatibility with old checked-in code +template +void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +void Mul(const int32* input1_data, const Dims<4>& input1_dims, + const int32* input2_data, const Dims<4>& input2_dims, + int32* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mul/int32"); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() * input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() * scalar; + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar * input2_map.array(); + } else { + // Should not come here. + TFLITE_DCHECK(false); + } +} + +// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +// TODO(benoitjacob): BroadcastMul is intentionally duplicated from +// reference_ops.h. Once an optimized version is implemented and NdArrayDesc +// is no longer referenced in this file, move NdArrayDesc from types.h to +// reference_ops.h. +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 unclamped_result = + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOne( + input1_val * input2_val, output_multiplier, output_shift); + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, unclamped_result)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +template +void Concatenation(int concat_dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Concatenation"); + int concat_size = 0; + for (int i = 0; i < inputs_count; i++) { + for (int j = 0; j < 4; j++) { + if (j != concat_dim) { + MatchingArraySize(*input_dims[i], j, output_dims, j); + } + } + concat_size += ArraySize(*input_dims[i], concat_dim); + } + TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + // for now we dont have a model with a Concatenation + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + int outer_size = 1; + for (int i = concat_dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + const int copy_size = + input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + +template +void DepthConcatenation(const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + Concatenation(0, input_data, input_dims, inputs_count, + output_data, output_dims); +} + +inline void LstmCell(const float* input_data, const Dims<4>& input_dims, + const float* prev_activ_data, + const Dims<4>& prev_activ_dims, const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, const float* prev_state_data, + const Dims<4>& prev_state_dims, float* output_state_data, + const Dims<4>& output_state_dims, float* output_activ_data, + const Dims<4>& output_activ_dims, float* concat_temp_data, + const Dims<4>& concat_temp_dims, float* activ_temp_data, + const Dims<4>& activ_temp_dims) { + gemmlowp::ScopedProfilingLabel label("LstmCell"); + MatchingArraySize( // batches + input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, output_state_dims, + 3, output_activ_dims, 3); + MatchingArraySize( // height + input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, output_state_dims, + 2, output_activ_dims, 2); + MatchingArraySize( // width + input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, output_state_dims, + 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector const*> concat_input_arrays_dims; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + concat_input_arrays_dims.push_back(&input_dims); + concat_input_arrays_dims.push_back(&prev_activ_dims); + Concatenation( + 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]), + concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims); + + // Fully connected + FullyConnected( + concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data, + bias_dims, activ_temp_data, activ_temp_dims); + + // Map raw arrays to Eigen arrays so we can use Eigen's optimized array + // operations. + ArrayMap activ_temp_map = + MapAsArrayWithFirstDimAsRows(activ_temp_data, activ_temp_dims); + auto input_gate_sm = activ_temp_map.block(0 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto new_input_sm = activ_temp_map.block(1 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto forget_gate_sm = activ_temp_map.block(2 * output_depth, 0, output_depth, + activ_temp_map.cols()); + auto output_gate_sm = activ_temp_map.block(3 * output_depth, 0, output_depth, + activ_temp_map.cols()); + ArrayMap prev_state_map = + MapAsArrayWithFirstDimAsRows(prev_state_data, prev_state_dims); + ArrayMap output_state_map = + MapAsArrayWithFirstDimAsRows(output_state_data, output_state_dims); + ArrayMap output_activ_map = + MapAsArrayWithFirstDimAsRows(output_activ_data, output_activ_dims); + + // Combined memory state and final output calculation + gemmlowp::ScopedProfilingLabel label2("MemoryStateAndFinalOutput"); + output_state_map = + input_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + new_input_sm.tanh() + + forget_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + prev_state_map; + output_activ_map = + output_gate_sm.unaryExpr(Eigen::internal::scalar_sigmoid_op()) * + output_state_map.tanh(); +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowSplit"); + TFLITE_DCHECK_GE(outputs_count, 1); + for (int i = 0; i < outputs_count; i++) { + /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); + /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); + /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + } + const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); + const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); + const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + // for now we dont have a model with a TensorFlowSplit + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + const int whb = width * height * batches; + const Scalar* input_ptr = input_data; + for (int k = 0; k < whb; k++) { + for (int i = 0; i < outputs_count; ++i) { + memcpy(output_data[i] + k * output_dims[i]->sizes[0], input_ptr, + output_dims[i]->sizes[0] * sizeof(Scalar)); + input_ptr += output_dims[i]->sizes[0]; + } + } +} + +inline int NodeOffset(int b, int h, int w, int height, int width) { + return (b * height + h) * width + w; +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("AveragePool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + // TODO(benoitjacob) make this a proper reference impl without Eigen! + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // TODO(benoitjacob) get rid of the dynamic memory allocation here! + Eigen::VectorXf out_count(out_mat.cols()); + out_count.setZero(); + // Prefill the output to 0. + out_mat.setZero(); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + int hpad = h + pad_height; + int wpad = w + pad_width; + int h_start = + (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int h_end = std::min(hpad / stride_height + 1, output_height); + int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_end = std::min(wpad / stride_width + 1, output_width); + // compute elementwise sum + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + int out_offset = NodeOffset(b, ph, pw, output_height, output_width); + out_mat.col(out_offset) += + in_mat.col(NodeOffset(b, h, w, input_height, input_width)); + out_count(out_offset)++; + } + } + } + } + } + // Divide the output by the actual number of elements being averaged over + TFLITE_DCHECK_GT(out_count.minCoeff(), 0); + out_mat.array().rowwise() /= out_count.transpose().array(); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + for (int x = 0; x < output_width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + output_data[Offset(output_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("AveragePool/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + const int filter_count = + (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); + // 1280 required by Inception v3 + static constexpr int kAccBufferMaxSize = 2048; + TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); + uint16 acc[kAccBufferMaxSize]; + memset(acc, 0, depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + input_dims.strides[1] * in_x_origin + + input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + + filter_x_start * input_dims.strides[1]; + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint16x8_t acc_reg[2]; + for (int i = 0; i < 2; i++) { + acc_reg[i] = vld1q_u16(acc + channel + 8 * i); + } + uint8x16_t input_reg = vld1q_u8(input_row_ptr); + input_row_ptr += 16; + acc_reg[0] = vaddw_u8(acc_reg[0], vget_low_u8(input_reg)); + acc_reg[1] = vaddw_u8(acc_reg[1], vget_high_u8(input_reg)); + for (int i = 0; i < 2; i++) { + vst1q_u16(acc + channel + 8 * i, acc_reg[i]); + } + } + for (; channel <= depth - 8; channel += 8) { + uint16x8_t acc_reg = vld1q_u16(acc + channel); + uint8x8_t input_reg = vld1_u8(input_row_ptr); + input_row_ptr += 8; + acc_reg = vaddw_u8(acc_reg, input_reg); + vst1q_u16(acc + channel, acc_reg); + } +#endif + for (; channel < depth; ++channel) { + acc[channel] += *input_row_ptr++; + } + } + } + uint8* output_ptr = + output_data + Offset(output_dims, 0, out_x, out_y, batch); + int channel = 0; +#ifdef USE_NEON +#define AVGPOOL_DIVIDING_BY(FILTER_COUNT) \ + if (filter_count == FILTER_COUNT) { \ + for (; channel <= depth - 8; channel += 8) { \ + uint16 buf[8]; \ + for (int i = 0; i < 8; i++) { \ + buf[i] = (acc[channel + i] + FILTER_COUNT / 2) / FILTER_COUNT; \ + } \ + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); \ + buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); \ + buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); \ + vst1_u8(output_ptr + channel, buf8); \ + } \ + } + AVGPOOL_DIVIDING_BY(9) + AVGPOOL_DIVIDING_BY(15) +#undef AVGPOOL_DIVIDING_BY + for (; channel <= depth - 8; channel += 8) { + uint16 buf[8]; + for (int i = 0; i < 8; i++) { + buf[i] = (acc[channel + i] + filter_count / 2) / filter_count; + } + uint8x8_t buf8 = vqmovn_u16(vld1q_u16(buf)); + buf8 = vmin_u8(buf8, vdup_n_u8(output_activation_max)); + buf8 = vmax_u8(buf8, vdup_n_u8(output_activation_min)); + vst1_u8(output_ptr + channel, buf8); + } +#endif + for (; channel < depth; ++channel) { + uint16 a = (acc[channel] + filter_count / 2) / filter_count; + a = std::max(a, output_activation_min); + a = std::min(a, output_activation_max); + output_ptr[channel] = static_cast(a); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int kwidth, int kheight, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("MaxPool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // Prefill the output to minimum representable float value + out_mat.setConstant(std::numeric_limits::lowest()); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + int hpad = h + pad_height; + int wpad = w + pad_width; + int h_start = + (hpad < kheight) ? 0 : (hpad - kheight) / stride_height + 1; + int h_end = std::min(hpad / stride_height + 1, output_height); + int w_start = (wpad < kwidth) ? 0 : (wpad - kwidth) / stride_width + 1; + int w_end = std::min(wpad / stride_width + 1, output_width); + // compute elementwise sum + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + int out_offset = NodeOffset(b, ph, pw, output_height, output_width); + out_mat.col(out_offset) = + out_mat.col(out_offset) + .cwiseMax(in_mat.col( + NodeOffset(b, h, w, input_height, input_width))); + } + } + } + } + } + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + for (int x = 0; x < output_width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + output_data[Offset(output_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int kwidth, int kheight, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, kwidth, kheight, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("MaxPool/8bit"); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + // 2048 required by Inception v3 + static constexpr int kAccBufferMaxSize = 2048; + TFLITE_DCHECK_LE(depth, kAccBufferMaxSize); + uint8 acc[kAccBufferMaxSize]; + memset(acc, 0, depth * sizeof(acc[0])); + const uint8* input_ptr = + input_data + input_dims.strides[1] * in_x_origin + + input_dims.strides[2] * in_y_origin + input_dims.strides[3] * batch; + for (int fy = filter_y_start; fy < filter_y_end; fy++) { + const uint8* input_row_ptr = input_ptr + fy * input_dims.strides[2] + + filter_x_start * input_dims.strides[1]; + for (int fx = filter_x_start; fx < filter_x_end; fx++) { + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint8x16_t acc_reg = vld1q_u8(acc + channel); + uint8x16_t input_reg = vld1q_u8(input_row_ptr); + input_row_ptr += 16; + acc_reg = vmaxq_u8(acc_reg, input_reg); + vst1q_u8(acc + channel, acc_reg); + } + + for (; channel <= depth - 8; channel += 8) { + uint8x8_t acc_reg = vld1_u8(acc + channel); + uint8x8_t input_reg = vld1_u8(input_row_ptr); + input_row_ptr += 8; + acc_reg = vmax_u8(acc_reg, input_reg); + vst1_u8(acc + channel, acc_reg); + } +#endif + for (; channel < depth; ++channel) { + acc[channel] = std::max(acc[channel], *input_row_ptr++); + } + } + } + uint8* output_ptr = + output_data + Offset(output_dims, 0, out_x, out_y, batch); + int channel = 0; +#ifdef USE_NEON + for (; channel <= depth - 16; channel += 16) { + uint8x16_t a = vld1q_u8(acc + channel); + a = vminq_u8(a, vdupq_n_u8(output_activation_max)); + a = vmaxq_u8(a, vdupq_n_u8(output_activation_min)); + vst1q_u8(output_ptr + channel, a); + } + for (; channel <= depth - 8; channel += 8) { + uint8x8_t a = vld1_u8(acc + channel); + a = vmin_u8(a, vdup_n_u8(output_activation_max)); + a = vmax_u8(a, vdup_n_u8(output_activation_min)); + vst1_u8(output_ptr + channel, a); + } +#endif + for (; channel < depth; ++channel) { + uint8 a = acc[channel]; + a = std::max(a, output_activation_min); + a = std::min(a, output_activation_max); + output_ptr[channel] = static_cast(a); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("L2Pool"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + // Actually carry out L2 Pool. Code is written in forward mode: we go through + // the input values once, and write to all the pooled regions that it maps to. + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + Eigen::VectorXf in_square(in_mat.rows()); + Eigen::VectorXf out_count(out_mat.cols()); + out_count.setZero(); + // Prefill the output to 0. + out_mat.setZero(); + for (int b = 0; b < batches; ++b) { + for (int h = 0; h < input_height; ++h) { + for (int w = 0; w < input_width; ++w) { + // (h_start, h_end) * (w_start, w_end) is the range that the input + // vector projects to. + const int hpad = h + pad_height; + const int wpad = w + pad_width; + const int h_start = (hpad < filter_height) + ? 0 + : (hpad - filter_height) / stride_height + 1; + const int h_end = std::min(hpad / stride_height + 1, output_height); + const int w_start = (wpad < filter_width) + ? 0 + : (wpad - filter_width) / stride_width + 1; + const int w_end = std::min(wpad / stride_width + 1, output_width); + // pre-compute square + const int in_offset = w + input_width * (h + input_height * b); + in_square = + in_mat.col(in_offset).array() * in_mat.col(in_offset).array(); + // compute elementwise sum of squares + for (int ph = h_start; ph < h_end; ++ph) { + for (int pw = w_start; pw < w_end; ++pw) { + const int out_offset = pw + output_width * (ph + output_height * b); + out_mat.col(out_offset) += in_square; + out_count(out_offset)++; + } + } + } + } + } + + out_count = out_count.array().inverse(); + out_mat = + (out_mat.array().rowwise() * out_count.transpose().array()).cwiseSqrt(); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization"); + /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + + // Carry out local response normalization, vector by vector. + // Since the data are stored column major, making row-wise operation + // probably not memory efficient anyway, we do an explicit for loop over + // the columns. + const int double_range = range * 2; + Eigen::VectorXf padded_square(data_in.rows() + double_range); + padded_square.setZero(); + for (int r = 0; r < data_in.cols(); ++r) { + // Do local response normalization for data_in(:, r) + // first, compute the square and store them in buffer for repeated use + padded_square.block(range, 0, data_in.rows(), 1) = + data_in.col(r).cwiseProduct(data_in.col(r)) * alpha; + // Then, compute the scale and writes them to data_out + float accumulated_scale = 0; + for (int i = 0; i < double_range; ++i) { + accumulated_scale += padded_square(i); + } + for (int i = 0; i < data_in.rows(); ++i) { + accumulated_scale += padded_square(i + double_range); + data_out(i, r) = bias + accumulated_scale; + accumulated_scale -= padded_square(i); + } + } + + // In a few cases, the pow computation could benefit from speedups. + if (beta == 1) { + data_out.array() = data_in.array() * data_out.array().inverse(); + } else if (beta == 0.5) { + data_out.array() = data_in.array() * data_out.array().sqrt().inverse(); + } else { + data_out.array() = data_in.array() * data_out.array().pow(-beta); + } +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Softmax"); + /* const int batches = */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* const int height = */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* const int width = */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* const int depth = */ MatchingArraySize(input_dims, 0, output_dims, 0); + + const auto in_mat = MapAsMatrixWithFirstDimAsRows(input_data, input_dims); + auto out_mat = MapAsMatrixWithFirstDimAsRows(output_data, output_dims); + // Compute the exponential first, removing the max coefficient for numerical + // stability. + out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta; + // We are separating out the exp function so that exp can be vectorized. + out_mat = out_mat.array().exp(); + // Normalize to get the activations. + Eigen::Array scale = + out_mat.array().colwise().sum().inverse(); + out_mat.array().rowwise() *= scale; +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + // The representation chosen for the input to the exp() function is Q5.26. + // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. + static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + gemmlowp::ScopedProfilingLabel label("Softmax"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int x = 0; x < width; ++x) { + for (int y = 0; y < height; ++y) { + uint8 max_in_row = 0; + for (int c = 0; c < depth; ++c) { + max_in_row = + std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + } + + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + + int32 fixed_sum_of_exps = sum_of_exps.raw(); + // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead. + int headroom_plus_one = + __builtin_clz(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. + int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data[Offset(output_dims, c, x, y, b)] = + std::max(std::min(unsat_output, 255), 0); + + } else { + output_data[Offset(output_dims, c, x, y, b)] = 0; + } + } + } + } + } +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = + input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op()); +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Logistic"); + /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0); + const int size = RequiredBufferSizeForDims(input_dims); + + int c = 0; +#ifdef USE_NEON + // Handle 16 values at a time + for (; c <= size - 16; c += 16) { + // Read input uint8 values, cast to int16 and subtract input_zero_point + uint8x16_t input_val_u8 = vld1q_u8(input_data + c); + int16x8_t input_val_centered_0 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + int16x8_t input_val_centered_1 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + + // Prepare the bit masks that we will use at the end to implement the logic + // that was expressed in the scalar code with branching: + // if (input_val_centered < -input_range_radius) { + // output_val = 0; + // } else if (input_val_centered > input_range_radius) { + // output_val = 255; + // } else { + // ... + uint16x8_t mask_rightclamp_0 = + vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_rightclamp_1 = + vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_leftclamp_0 = + vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius)); + uint16x8_t mask_leftclamp_1 = + vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius)); + uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), + vshrn_n_u16(mask_rightclamp_1, 8)); + uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), + vshrn_n_u16(mask_leftclamp_1, 8)); + + // This performs what is expressed in the scalar code as + // const int32 input_val_rescaled = + // MultiplyByQuantizedMultiplierGreaterThanOne( + // input_val_centered, input_multiplier, input_left_shift); + int32x4_t input_val_rescaled_0 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_1 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_2 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_3 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + input_val_rescaled_0 = + vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier); + input_val_rescaled_1 = + vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier); + input_val_rescaled_2 = + vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier); + input_val_rescaled_3 = + vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier); + + // Invoke gemmlowp::logistic on FixedPoint wrapping int32x4_t + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4_0 = + FixedPoint4::FromRaw(input_val_rescaled_0); + const FixedPoint4 input_val_f4_1 = + FixedPoint4::FromRaw(input_val_rescaled_1); + const FixedPoint4 input_val_f4_2 = + FixedPoint4::FromRaw(input_val_rescaled_2); + const FixedPoint4 input_val_f4_3 = + FixedPoint4::FromRaw(input_val_rescaled_3); + const FixedPoint0 output_val_f0_0 = gemmlowp::logistic(input_val_f4_0); + const FixedPoint0 output_val_f0_1 = gemmlowp::logistic(input_val_f4_1); + const FixedPoint0 output_val_f0_2 = gemmlowp::logistic(input_val_f4_2); + const FixedPoint0 output_val_f0_3 = gemmlowp::logistic(input_val_f4_3); + + // Divide by 2^23 as in the scalar code + using gemmlowp::RoundingDivideByPOT; + int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 23); + int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 23); + int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 23); + int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 23); + + // Cast output values to uint8, saturating + int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), + vqmovn_s32(output_val_s32_1)); + int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), + vqmovn_s32(output_val_s32_3)); + uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0), + vqmovun_s16(output_val_s16_1)); + + // Perform the bit-masking with the bit masks computed at the beginning, + // see the comment there. + output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp); + output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp); + + // Store back to memory + vst1q_u8(output_data + c, output_val_u8); + } +#endif + // Leftover loop: handle one value at a time with scalar code. + for (; c < size; ++c) { + const uint8 input_val_u8 = input_data[c]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered < -input_range_radius) { + output_val = 0; + } else if (input_val_centered > input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[c] = output_val; + } +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Tanh"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = input_map.array().tanh(); +} + +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Dequantize"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int32 val = input_data[Offset(input_dims, c, x, y, b)]; + float result = static_cast(scale * (val - zero_point)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("FakeQuant"); + + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_DCHECK_LE(rmin, 0.); + TFLITE_DCHECK_GE(rmax, 0.); + + // Determine quantization parameters: zero_point, scale. + using Integer = uint8; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const float qmin_float = qmin; + const float qmax_float = qmax; + int32 zero_point = 0; + float scale = 0.f; + // If rmin==rmax, both must be zero per the above assertion, + // so we are done. + if (rmin != rmax) { + // First determine the scale. + scale = (rmax - rmin) / (qmax_float - qmin_float); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const float zero_point_from_min = qmin_float - rmin / scale; + const float zero_point_from_max = qmax_float - rmax / scale; + const float zero_point_from_min_error = + std::abs(qmin_float) + std::abs(rmin / scale); + const float zero_point_from_max_error = + std::abs(qmax_float) + std::abs(rmax / scale); + + const float zero_point_float = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + if (zero_point_float < qmin_float) { + zero_point = qmin; + } else if (zero_point_float > qmax_float) { + zero_point = qmax; + } else { + zero_point = static_cast(TfLiteRound(zero_point_float)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + TFLITE_DCHECK_GE(zero_point, qmin); + TFLITE_DCHECK_LE(zero_point, qmax); + } + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const float src_val = input_data[Offset(input_dims, c, x, y, b)]; + const float unclamped_quantized_val = + TfLiteRound(zero_point + src_val / scale); + const float quantized_val = std::min( + qmax_float, std::max(qmin_float, unclamped_quantized_val)); + const float dst_val = scale * (quantized_val - zero_point); + output_data[Offset(output_dims, c, x, y, b)] = dst_val; + } + } + } + } +} + +template +inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, + DstT* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Cast"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = input_map.array().template cast(); +} + +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Floor"); + auto input_map = MapAsVector(input_data, input_dims); + auto output_map = MapAsVector(output_data, output_dims); + output_map.array() = Eigen::floor(input_map.array()); +} + +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Gather"); + + TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); + int stride = input_dims.strides[input_rank - 1]; + T* out = output_data; + + for (int i = 0; i < coords_dims.sizes[0]; i++) { + TFLITE_DCHECK_GE(coords_data[i], 0); + TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + const T* in = input_data + coords_data[i] * stride; + memcpy(out, in, sizeof(T) * stride); + out += stride; + } +} + +#ifdef USE_NEON +inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, + float scale, float* output_ptr) { + int ic = 0; + // Handle 32 input channels at a time. + for (; ic <= depth - 32; ic += 32) { + float32x4x2_t input[4]; + for (int i = 0; i < 4; i++) { + input[i].val[0] = vld1q_f32(input_ptr + 8 * i); + input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4); + } + float32x4x2_t acc[4]; + for (int i = 0; i < 4; i++) { + acc[i].val[0] = vld1q_f32(output_ptr + 8 * i); + acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4); + } + for (int i = 0; i < 4; i++) { + acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale); + acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale); + } + for (int i = 0; i < 4; i++) { + vst1q_f32(output_ptr, acc[i].val[0]); + vst1q_f32(output_ptr + 4, acc[i].val[1]); + output_ptr += 8; + } + input_ptr += 32; + } + // Handle 16 input channels at a time. + for (; ic <= depth - 16; ic += 16) { + float32x4x2_t input[2]; + for (int i = 0; i < 2; i++) { + input[i].val[0] = vld1q_f32(input_ptr + 8 * i); + input[i].val[1] = vld1q_f32(input_ptr + 8 * i + 4); + } + float32x4x2_t acc[2]; + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vld1q_f32(output_ptr + 8 * i); + acc[i].val[1] = vld1q_f32(output_ptr + 8 * i + 4); + } + for (int i = 0; i < 2; i++) { + acc[i].val[0] = vmlaq_n_f32(acc[i].val[0], input[i].val[0], scale); + acc[i].val[1] = vmlaq_n_f32(acc[i].val[1], input[i].val[1], scale); + } + for (int i = 0; i < 2; i++) { + vst1q_f32(output_ptr, acc[i].val[0]); + vst1q_f32(output_ptr + 4, acc[i].val[1]); + output_ptr += 8; + } + input_ptr += 16; + } + // Handle 8 input channels at a time. + for (; ic <= depth - 8; ic += 8) { + float32x4x2_t input; + input.val[0] = vld1q_f32(input_ptr); + input.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t acc; + acc.val[0] = vld1q_f32(output_ptr); + acc.val[1] = vld1q_f32(output_ptr + 4); + acc.val[0] = vmlaq_n_f32(acc.val[0], input.val[0], scale); + acc.val[1] = vmlaq_n_f32(acc.val[1], input.val[1], scale); + + vst1q_f32(output_ptr, acc.val[0]); + vst1q_f32(output_ptr + 4, acc.val[1]); + + input_ptr += 8; + output_ptr += 8; + } + // Handle 4 input channels at a time. + for (; ic <= depth - 4; ic += 4) { + float32x4_t input = vld1q_f32(input_ptr); + float32x4_t acc = vld1q_f32(output_ptr); + + acc = vmlaq_n_f32(acc, input, scale); + vst1q_f32(output_ptr, acc); + + input_ptr += 4; + output_ptr += 4; + } + // Handle 1 input channel at a time. + for (; ic < depth; ic++) { + *output_ptr += *input_ptr * scale; + output_ptr++; + input_ptr++; + } +} +#else +inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, + float scale, float* output_ptr) { + for (int32 i = 0; i < depth; i++) { + *output_ptr += *input_ptr * scale; + output_ptr++; + input_ptr++; + } +} +#endif + +inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, + int32 x, int32 y, int32 depth, int32 batch, + const float* input_data, + const Dims<4>& input_dims, + float* output_data, + const Dims<4>& output_dims) { + const int32 input_width = ArraySize(input_dims, 1); + const int32 output_width = ArraySize(output_dims, 1); + + const int32 input_x_offset = (x1 - x0) * depth; + const int32 input_y_offset = (y1 - y0) * depth * input_width; + const int32 output_x_offset = depth; + const int32 output_y_offset = depth * output_width; + +#ifdef USE_NEON + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(x1 >= x0); + TFLITE_DCHECK(y1 >= y0); + + int ic = 0; + // Handle 8 input channels at a time. + for (; ic <= depth - 8; ic += 8) { + const float* input_ptr = nullptr; + + float32x4x2_t x0y0; + input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + x0y0.val[0] = vld1q_f32(input_ptr); + x0y0.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x1y0; + input_ptr += input_x_offset; + x1y0.val[0] = vld1q_f32(input_ptr); + x1y0.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x0y1; + input_ptr += -input_x_offset + input_y_offset; + x0y1.val[0] = vld1q_f32(input_ptr); + x0y1.val[1] = vld1q_f32(input_ptr + 4); + + float32x4x2_t x1y1; + input_ptr += input_x_offset; + x1y1.val[0] = vld1q_f32(input_ptr); + x1y1.val[1] = vld1q_f32(input_ptr + 4); + + // Top left corner. + float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + vst1q_f32(output_ptr, x0y0.val[0]); + vst1q_f32(output_ptr + 4, x0y0.val[1]); + + // Top right corner. + output_ptr += output_x_offset; + float32x4x2_t tr; + tr.val[0] = vaddq_f32(x0y0.val[0], x1y0.val[0]); + tr.val[1] = vaddq_f32(x0y0.val[1], x1y0.val[1]); + tr.val[0] = vmulq_n_f32(tr.val[0], 0.5f); + tr.val[1] = vmulq_n_f32(tr.val[1], 0.5f); + + vst1q_f32(output_ptr, tr.val[0]); + vst1q_f32(output_ptr + 4, tr.val[1]); + + // Bottom left corner. + output_ptr += -output_x_offset + output_y_offset; + float32x4x2_t bl; + bl.val[0] = vaddq_f32(x0y0.val[0], x0y1.val[0]); + bl.val[1] = vaddq_f32(x0y0.val[1], x0y1.val[1]); + bl.val[0] = vmulq_n_f32(bl.val[0], 0.5f); + bl.val[1] = vmulq_n_f32(bl.val[1], 0.5f); + vst1q_f32(output_ptr, bl.val[0]); + vst1q_f32(output_ptr + 4, bl.val[1]); + + // Bottom right corner. + output_ptr += output_x_offset; + float32x4x2_t br; + br.val[0] = vaddq_f32(x1y0.val[0], x1y1.val[0]); + br.val[1] = vaddq_f32(x1y0.val[1], x1y1.val[1]); + br.val[0] = vmlaq_n_f32(bl.val[0], br.val[0], 0.5f); + br.val[1] = vmlaq_n_f32(bl.val[1], br.val[1], 0.5f); + br.val[0] = vmulq_n_f32(br.val[0], 0.5f); + br.val[1] = vmulq_n_f32(br.val[1], 0.5f); + vst1q_f32(output_ptr, br.val[0]); + vst1q_f32(output_ptr + 4, br.val[1]); + } + // Handle 4 input channels at a time. + for (; ic <= depth - 4; ic += 4) { + const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)]; + float32x4_t x0y0 = vld1q_f32(input_ptr); + float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset); + float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset); + float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset); + + // Top left corner. + float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)]; + vst1q_f32(output_ptr, x0y0); + + // Top right corner. + output_ptr += output_x_offset; + float32x4_t tr = vaddq_f32(x0y0, x1y0); + tr = vmulq_n_f32(tr, 0.5f); + vst1q_f32(output_ptr, tr); + + // Bottom left corner. + output_ptr += -output_x_offset + output_y_offset; + float32x4_t bl = vaddq_f32(x0y0, x0y1); + bl = vmulq_n_f32(bl, 0.5f); + vst1q_f32(output_ptr, bl); + + // Bottom right corner. + output_ptr += output_x_offset; + float32x4_t br = vaddq_f32(x1y0, x1y1); + br = vmlaq_n_f32(bl, br, 0.5f); + br = vmulq_n_f32(br, 0.5f); + vst1q_f32(output_ptr, br); + } + // Handle one input channel at a time. + for (; ic < depth; ic++) { + const int32 input_offset = Offset(input_dims, ic, x0, y0, batch); + + float x0y0 = input_data[input_offset]; + float x1y0 = input_data[input_offset + input_x_offset]; + float x0y1 = input_data[input_offset + input_y_offset]; + float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; + + // Top left corner. + const int32 output_offset = Offset(output_dims, ic, x, y, batch); + output_data[output_offset] = x0y0; + + // Top right corner. + output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2; + + // Bottom left corner. + float output = (x0y0 + x0y1) / 2; + output_data[output_offset + output_y_offset] = output; + + // Bottom right corner. + output_data[output_offset + output_x_offset + output_y_offset] = + (output + ((x1y0 + x1y1) / 2)) / 2; + } +#else + for (int ch = 0; ch < depth; ch++) { + const int32 input_offset = Offset(input_dims, ch, x0, y0, batch); + + float x0y0 = input_data[input_offset]; + float x1y0 = input_data[input_offset + input_x_offset]; + float x0y1 = input_data[input_offset + input_y_offset]; + float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; + + // Top left corner. + const int32 output_offset = Offset(output_dims, ch, x, y, batch); + output_data[output_offset] = x0y0; + + // Top right corner. + output_data[output_offset + output_x_offset] = (x0y0 + x1y0) / 2; + + // Bottom left corner. + float output = (x0y0 + x0y1) / 2; + output_data[output_offset + output_y_offset] = output; + + // Bottom right corner. + output_data[output_offset + output_x_offset + output_y_offset] = + (output + ((x1y0 + x1y1) / 2)) / 2; + } +#endif +} + +inline void ResizeBilinear2x2(const float* input_data, + const Dims<4>& input_dims, float* output_data, + const Dims<4>& output_dims, int32 batches, + int32 input_height, int32 input_width, + int32 depth, int32 output_height, + int32 output_width) { + for (int b = 0; b < batches; b++) { + for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) { + for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) { + int32 x1 = std::min(x0 + 1, input_width - 1); + int32 y1 = std::min(y0 + 1, input_height - 1); + ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data, + input_dims, output_data, output_dims); + } + } + } +} + +inline void ResizeBilinearGeneric(const float* input_data, + const Dims<4>& input_dims, float* output_data, + const Dims<4>& output_dims, int32 batches, + int32 input_height, int32 input_width, + int32 depth, int32 output_height, + int32 output_width, float height_scale, + float width_scale) { + memset(output_data, 0, + batches * output_height * output_width * depth * sizeof(float)); + + int32 output_offset = 0; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(input_x); + int32 x1 = std::min(x0 + 1, input_width - 1); + float* output_ptr = &output_data[output_offset]; + + // Run kernel on the 4 corners of the bilinear resize algorithm. + int32 input_offset = Offset(input_dims, 0, x0, y0, b); + float scale = (1 - (input_y - y0)) * (1 - (input_x - x0)); + const float* input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x1, y0, b); + scale = (1 - (input_y - y0)) * (input_x - x0); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x0, y1, b); + scale = (input_y - y0) * (1 - (input_x - x0)); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + input_offset = Offset(input_dims, 0, x1, y1, b); + scale = (input_y - y0) * (input_x - x0); + input_ptr = &input_data[input_offset]; + ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); + + output_offset += depth; + } + } + } +} + +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + + // Specialize for 2x2 upsample. + if (output_height == 2 * input_height && output_width == 2 * input_width) { + ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches, + input_height, input_width, depth, output_height, + output_width); + } else { + float height_scale = static_cast(input_height) / output_height; + float width_scale = static_cast(input_width) / output_width; + + ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims, + batches, input_height, input_width, depth, + output_height, output_width, height_scale, + width_scale); + } +} + +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("SpaceToBatchND"); + + const int output_batch_size = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_height = block_shape_data[0]; + const int block_shape_width = block_shape_data[1]; + const int padding_top = paddings_data[0]; + const int padding_left = paddings_data[2]; + + for (int out_b = 0; out_b < output_batch_size; ++out_b) { + int input_batch = out_b % input_batch_size; + int shift_w = (out_b / input_batch_size) % block_shape_width; + int shift_h = (out_b / input_batch_size) / block_shape_width; + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + if (out_h * block_shape_height < padding_top || + out_h * block_shape_height >= padding_top + input_height || + out_w * block_shape_width < padding_left || + out_w * block_shape_width >= padding_left + input_width) { + memset(out, 0, depth * sizeof(T)); + } else { + const T* in = + input_data + + Offset(input_dims, 0, + (out_w * block_shape_width + shift_w) - padding_left, + (out_h * block_shape_height + shift_h) - padding_top, + input_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } + } +} + +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BatchToSpaceND"); + + const int output_batch_size = ArraySize(output_dims, 3); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_width = block_shape_data[1]; + const int block_shape_height = block_shape_data[0]; + + for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + int out_batch = in_batch % output_batch_size; + int out_w = in_w * block_shape_width + + (in_batch / output_batch_size) % block_shape_width; + int out_h = in_h * block_shape_height + + (in_batch / output_batch_size) / block_shape_width; + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); + const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } +} + +template +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Pad"); + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int left_b_padding = left_paddings[3]; + const int left_h_padding = left_paddings[2]; + const int left_w_padding = left_paddings[1]; + const int left_d_padding = left_paddings[0]; + + const int right_b_padding = right_paddings[3]; + const int right_h_padding = right_paddings[2]; + const int right_w_padding = right_paddings[1]; + const int right_d_padding = right_paddings[0]; + + const int input_depth = ArraySize(input_dims, 0); + + if (left_b_padding != 0) { + memset(output_data, 0, + left_b_padding * output_height * output_width * output_depth * + sizeof(T)); + } + for (int out_b = left_b_padding; out_b < output_batch - right_b_padding; + ++out_b) { + if (left_h_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, 0, out_b), 0, + left_h_padding * output_width * output_depth * sizeof(T)); + } + for (int out_h = left_h_padding; out_h < output_height - right_h_padding; + ++out_h) { + if (left_w_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, out_h, out_b), 0, + left_w_padding * output_depth * sizeof(T)); + } + for (int out_w = left_w_padding; out_w < output_width - right_w_padding; + ++out_w) { + if (left_d_padding != 0) { + memset(output_data + Offset(output_dims, 0, out_w, out_h, out_b), 0, + left_d_padding * sizeof(T)); + } + + T* out = output_data + + Offset(output_dims, left_d_padding, out_w, out_h, out_b); + const T* in = + input_data + Offset(input_dims, 0, out_w - left_w_padding, + out_h - left_h_padding, out_b - left_b_padding); + memcpy(out, in, input_depth * sizeof(T)); + + if (right_d_padding != 0) { + memset( + output_data + Offset(output_dims, output_depth - right_d_padding, + out_w, out_h, out_b), + 0, right_d_padding * sizeof(T)); + } + } + if (right_w_padding != 0) { + memset( + output_data + Offset(output_dims, 0, output_width - right_w_padding, + out_h, out_b), + 0, right_w_padding * output_depth * sizeof(T)); + } + } + if (right_h_padding != 0) { + memset(output_data + Offset(output_dims, 0, 0, + output_height - right_h_padding, out_b), + 0, right_h_padding * output_width * output_depth * sizeof(T)); + } + } + if (right_b_padding != 0) { + memset(output_data + + Offset(output_dims, 0, 0, 0, output_batch - right_b_padding), + 0, + right_b_padding * output_height * output_width * output_depth * + sizeof(T)); + } +} + +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("StridedSlice"); + const int start_b = (begin_mask & 8) ? 0 : starts[3]; + const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; + const int start_h = (begin_mask & 4) ? 0 : starts[2]; + const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; + const int start_w = (begin_mask & 2) ? 0 : starts[1]; + const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; + const int start_d = (begin_mask & 1) ? 0 : starts[0]; + const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0]; + + T* out_ptr = output_data; + if (strides[0] == 0) { + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + const int len = stop_d - start_d; + memcpy(out_ptr, + input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + len * sizeof(T)); + out_ptr += len; + } + } + } + } else { + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } + } +} + +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + // TODO(dkalenichenko): This op only supports 4D tensors. + TFLITE_DCHECK_EQ(begin.size(), 4); + TFLITE_DCHECK_EQ(size.size(), 4); + const int start_b = begin[3]; + const int stop_b = + size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; + const int start_h = begin[2]; + const int stop_h = + size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2]; + const int start_w = begin[1]; + const int stop_w = + size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1]; + const int start_d = begin[0]; + const int stop_d = + size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; ++in_b) { + for (int in_h = start_h; in_h < stop_h; ++in_h) { + for (int in_w = start_w; in_w < stop_w; ++in_w) { + const int len = stop_d - start_d; + memcpy(out_ptr, + input_data + Offset(input_dims, start_d, in_w, in_h, in_b), + len * sizeof(T)); + out_ptr += len; + } + } + } +} + +template +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Mean"); + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + + // The current implementation only supports simultaneous reduction over + // width and height. + TFLITE_DCHECK_EQ(reduction_indices.size(), 2); + TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || + (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + float value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + } + } + output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + value / (input_width * input_height); + } + } +} + +template +void GenericBroadcastSub(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GenericBroadcastSub"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input1_data[SubscriptToIndex(desc1, c, x, y, b)] - + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + } + } + } + } +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("Sub"); + + auto input1_map = MapAsVector(input1_data, input1_dims); + auto input2_map = MapAsVector(input2_data, input2_dims); + auto output_map = MapAsVector(output_data, output_dims); + if (AreSameDims(input1_dims, input2_dims)) { + output_map.array() = input1_map.array() - input2_map.array(); + } else if (RequiredBufferSizeForDims(input1_dims) == 1) { + auto scalar = input1_data[0]; + output_map.array() = scalar - input2_map.array(); + } else if (RequiredBufferSizeForDims(input2_dims) == 1) { + auto scalar = input2_data[0]; + output_map.array() = input1_map.array() - scalar; + } else { + GenericBroadcastSub(input1_data, input1_dims, input2_data, input2_dims, + output_data, output_dims); + } +} + +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMinimum"); + auto input1_map = MapAsVector(input1_data, input1_dims); + auto output_map = MapAsVector(output_data, output_dims); + auto min_value = input2_data[0]; + output_map.array() = input1_map.array().min(min_value); +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("TensorFlowMaximum"); + auto input1_map = MapAsVector(input1_data, input1_dims); + auto output_map = MapAsVector(output_data, output_dims); + auto max_value = input2_data[0]; + output_map.array() = input1_map.array().max(max_value); +} +} // namespace optimized_ops +} // namespace tflite + +#if defined OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#undef OPTIMIZED_OPS_H__IGNORE_DEPRECATED_DECLARATIONS +#pragma GCC diagnostic pop +#endif + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..f8be99e82fb8721ced7a3e5da686b20ce241ea2d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ +#define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ + +// TDOD(ghodrat): Remove this header file and the dependency to internal data +// structure. +#include "tensorflow/contrib/lite/builtin_op_data.h" + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif // USE_NEON + +namespace tflite { +namespace tensor_utils { + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector. +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride); +void NeonMatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result); +void NeonVectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC operation, the +// assumption here is that result array is initialized to valid values. +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result); +void NeonVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Dot product of two vectors. +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); +float NeonVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors. +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); +void NeonBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result); +void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void PortableSub1Vector(const float* vector, int v_size, float* result); +void NeonSub1Vector(const float* vector, int v_size, float* result); + +// Clip elements of a vector using a abs_limit value. +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result); +void NeonClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Batch vector initialization with another vector. +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector); + +// Apply sigmoid to elements of a vector. +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result); + +// Apply activation function to elements of a vector. +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result); + +// Copy vector to another vector. +void PortableCopyVector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void PortableZeroVector(float* vector, int v_size); + +// Limit a float input f between +abs_limit and -abs_limit. +float PortableClip(float f, float abs_limit); + +// Shift left a vector in place with v_size size. +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); +void NeonVectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); +void NeonReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); + +} // namespace tensor_utils +} // namespace tflite + +#endif // TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..98f2e365c5249a6c28673fc185ebec34cc2105b2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" + +namespace tflite { + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift) { + TFLITE_CHECK(double_multiplier >= 0.); + TFLITE_CHECK(double_multiplier < 1.); + if (double_multiplier == 0.) { + *quantized_multiplier = 0; + *right_shift = 0; + return; + } + TFLITE_CHECK(double_multiplier > 0.); + const double q = std::frexp(double_multiplier, right_shift); + *right_shift *= -1; + + auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); + TFLITE_CHECK(q_fixed <= (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + --*right_shift; + } + TFLITE_CHECK_GE(*right_shift, 0); + TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); + *quantized_multiplier = static_cast(q_fixed); +} + +void QuantizeMultiplierGreaterThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift) { + TFLITE_CHECK(double_multiplier > 1.); + const double q = std::frexp(double_multiplier, left_shift); + auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); + TFLITE_CHECK(q_fixed <= (1ll << 31)); + if (q_fixed == (1ll << 31)) { + q_fixed /= 2; + ++*left_shift; + } + TFLITE_CHECK_GE(*left_shift, 0); + TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); + *quantized_multiplier = static_cast(q_fixed); +} + +void PreprocessSoftmaxScaling(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, int* left_shift) { + // If the overall multiplier (input and beta) is large, then exp() of an + // input difference of 1 scaled by this will be large. In other words, we + // can cap the multiplier and know that, when it is used, the output will be + // (round to) zero wherever the input is not at the maximum value. + + // If the overall scale is less than one, and input_integer_bits=0, then the + // result is double equivalent of Q0.31 (actually with more precision). Thus + // this generates a Q(input_integer_bits).(31-input_integer_bits) + // representation. + const double input_beta_real_multiplier = std::min( + beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0); + + QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier, + quantized_multiplier, left_shift); +} + +int CalculateInputRadius(int input_integer_bits, int input_left_shift) { + const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) * + (1ll << (31 - input_integer_bits)) / + (1ll << input_left_shift); + // Tighten bound using floor. Suppose that we could use the exact value. + // After scaling the difference, the result would be at the maximum. Thus we + // must ensure that our value has lower magnitude. + return static_cast(std::floor(max_input_rescaled)); +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h new file mode 100644 index 0000000000000000000000000000000000000000..efb7191c8deb2a23ea5473ab131d2b6537202765 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ +#define PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ + +#include + +namespace tflite { + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier < 1 (and non-negative). +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift); + +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Restricted to the case where the multiplier > 1. +void QuantizeMultiplierGreaterThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* left_shift); + +// This first creates a multiplier in a double equivalent of +// Q(input_integer_bits).(31-input_integer_bits) representation, with extra +// precision in the double's fractional bits. It then splits the result into +// significand and exponent. +void PreprocessSoftmaxScaling(double beta, double input_scale, + int input_integer_bits, + int32_t* quantized_multiplier, int* left_shift); + +// Calculate the largest input that will result in a within-bounds intermediate +// result within MultiplyByQuantizedMultiplierGreaterThanOne. In other words, +// it must not overflow before we reduce the value by multiplication by the +// input multiplier. The negative radius is used as the minimum difference +// in Softmax. +int CalculateInputRadius(int input_integer_bits, int input_left_shift); + +} // namespace tflite + +#endif // PHOTOS_VISION_LEARNING_TENSORFLOW_MINI_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d6f306e2cbae3c780b3d773638ba46cd2abf02f5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -0,0 +1,108 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" + +#include +#include + +namespace tflite { +namespace { + +using ::testing::Pair; + +TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { + auto quantize = [](double d) { + int32_t q; + int s; + QuantizeMultiplierSmallerThanOne(d, &q, &s); + return std::pair{q, s}; + }; + + EXPECT_DEATH(quantize(-0.1), ""); + EXPECT_THAT(quantize(0.0), Pair(0, 0)); + EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); + + // Around 0.5 we can see the change in exponent and how we try hard to + // void hitting max int32. + EXPECT_THAT(quantize(0.50 - 5e-9), Pair(2147483627, 1)); + EXPECT_THAT(quantize(0.50 - 1e-10), Pair(1073741824, 0)); + EXPECT_THAT(quantize(0.50), Pair(1073741824, 0)); + + EXPECT_THAT(quantize(0.75), Pair(1610612736, 0)); + EXPECT_THAT(quantize(1 - 1e-9), Pair(2147483646, 0)); + + // If we get close enough to 1.0 it crashes and dies in one of two ways: + // Either the shift becomes negative or we trigger the 'less-than-one' CHECK. + EXPECT_DEATH(quantize(1 - 1e-15), ""); + EXPECT_DEATH(quantize(1 - 1e-17), ""); + EXPECT_DEATH(quantize(1.0), ""); +} + +TEST(QuantizationUtilTest, QuantizeMultiplierGreaterThanOne) { + auto quantize = [](double d) { + int32_t q; + int s; + QuantizeMultiplierGreaterThanOne(d, &q, &s); + return std::pair{q, s}; + }; + + // If we are close enough to 1.0 it crashes. + EXPECT_DEATH(quantize(1 + 1e-16), ""); + + EXPECT_THAT(quantize(1 + 1e-11), Pair(1073741824, 1)); + EXPECT_THAT(quantize(1.25), Pair(1342177280, 1)); + EXPECT_THAT(quantize(1.50), Pair(1610612736, 1)); + EXPECT_THAT(quantize(1.75), Pair(1879048192, 1)); + + // Around the powers of two we see the change in exponent. Also, + // we try hard to avoid hitting max int32. + EXPECT_THAT(quantize(2 - 1e-9), Pair(2147483647, 1)); + EXPECT_THAT(quantize(2 - 1e-11), Pair(1073741824, 2)); + EXPECT_THAT(quantize(2), Pair(1073741824, 2)); +} + +TEST(QuantizationUtilTest, PreprocessSoftmaxScaling) { + auto quantize = [](double beta, double scale, int integer_bits) { + int32_t q; + int s; + PreprocessSoftmaxScaling(beta, scale, integer_bits, &q, &s); + return std::pair{q, s}; + }; + + // If beta * scale is greater than fits in the number of integer bits, the + // result is move near the maximum. Otherwise they quantize as expected. + // With 4 integer bits we can represent up to 16.0. + EXPECT_THAT(quantize(1.0, 16.0, 4), Pair(2147483647, 31)); + EXPECT_THAT(quantize(1.0, 8.0, 4), Pair(1073741824, 31)); + // But with 5 bits we can go further. + EXPECT_THAT(quantize(2.0, 16.0, 5), Pair(2147483647, 31)); + EXPECT_THAT(quantize(2.0, 8.0, 5), Pair(1073741824, 31)); +} + +TEST(QuantizationUtilTest, CalculateInputRadius) { + EXPECT_EQ(CalculateInputRadius(4, 27), 15); + EXPECT_EQ(CalculateInputRadius(3, 27), 14); + EXPECT_EQ(CalculateInputRadius(3, 28), 7); + EXPECT_EQ(CalculateInputRadius(4, 2), 503316480); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h new file mode 100644 index 0000000000000000000000000000000000000000..8e0f234545e43dd8b2412e065aaecad8325a1182 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ + +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = + input_data[Offset(input_dims, ic, in_x, in_y, b)]; + float filter_value = filter_data[Offset( + filter_dims, oc, filter_x, filter_y, 0)]; + total += (input_value * filter_value); + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, oc, 0, 0, 0)]; + } + output_data[Offset(output_dims, oc, out_x, out_y, b)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride_width, stride_height, pad_width, pad_height, + depth_multiplier, output_activation_min, output_activation_max, + output_data, output_dims); +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + float* output_data, const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, + depth_multiplier, output_data, output_dims); +} + +} // end namespace reference_ops +} // end namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h new file mode 100644 index 0000000000000000000000000000000000000000..8a80558b32f2858778460956cd9f57617674e21e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -0,0 +1,138 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ + +#include + +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int input_depth = ArraySize(input_dims, 0); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + TFLITE_DCHECK(output_depth == input_depth * depth_multiplier); + + for (int b = 0; b < batches; ++b) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int ic = 0; ic < input_depth; ++ic) { + for (int m = 0; m < depth_multiplier; m++) { + const int oc = m + ic * depth_multiplier; + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = + input_data[Offset(input_dims, ic, in_x, in_y, b)]; + int32 filter_val = filter_data[Offset(filter_dims, oc, + filter_x, filter_y, 0)]; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, oc, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, oc, out_x, out_y, b)] = + static_cast(acc); + } + } + } + } + } +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int depth_multiplier, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, + stride_height, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +// Legacy, for compatibility with old checked-in code. +template +void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int depth_multiplier, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + DepthwiseConv(input_data, input_dims, input_offset, filter_data, + filter_dims, filter_offset, bias_data, bias_dims, stride, + stride, pad_width, pad_height, depth_multiplier, + output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims); +} + +} // end namespace reference_ops +} // end namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..c5b0bccc9da5fa2ff9c3a9d430725b613435abf1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace tensor_utils { + +float PortableClip(float f, float abs_limit) { + float result = (abs_limit < f) ? abs_limit : f; + result = (-abs_limit > result) ? -abs_limit : result; + return result; +} + +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride) { + float* result_in_batch = result; + for (int b = 0; b < n_batch; b++) { + const float* matrix_ptr = matrix; + for (int r = 0; r < m_rows; r++) { + const float* vector_in_batch = vector + b * m_cols; + for (int c = 0; c < m_cols; c++) { + *result_in_batch += *matrix_ptr++ * *vector_in_batch++; + } + result_in_batch += result_stride; + } + } +} + +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = *vector1++ * *vector2++; + } +} + +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + float result = 0.0; + for (int v = 0; v < v_size; v++) { + result += *vector1++ * *vector2++; + } + return result; +} + +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + float* result_ptr = result; + const float* vector1_ptr = vector1; + const float* vector2_ptr = vector2; + for (int b = 0; b < n_batch; b++) { + *result_ptr = + PortableVectorVectorDotProduct(vector1_ptr, vector2_ptr, v_size); + vector1_ptr += v_size; + vector2_ptr += v_size; + result_ptr += result_stride; + } +} + +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result) { + for (int v = 0; v < v_size; v++) { + *result++ += *vector1++ * *vector2++; + } +} + +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result) { + for (int b = 0; b < n_batch; b++) { + for (int v = 0; v < v_size; v++) { + *result++ += vector[v] * *batch_vector++; + } + } +} + +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector) { + for (int b = 0; b < n_batch; b++) { + memcpy(batch_vector + b * v_size, vector, v_size * sizeof(float)); + } +} + +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result) { + auto sigmoid_func = ActivationFunctor(kTfLiteActSigmoid); + for (int v = 0; v < v_size; v++) { + *result++ = (sigmoid_func)(*vector++); + } +} + +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result) { + auto activation_func = ActivationFunctor(activation); + for (int v = 0; v < v_size; v++) { + *result++ = (activation_func)(*vector++); + } +} + +void PortableCopyVector(const float* vector, int v_size, float* result) { + memcpy(result, vector, v_size * sizeof(float)); +} + +void PortableSub1Vector(const float* vector, int v_size, float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = 1.0f - *vector++; + } +} + +void PortableZeroVector(float* vector, int v_size) { + memset(vector, 0, v_size * sizeof(float)); +} + +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + for (int v = 0; v < v_size; v++) { + *result++ = PortableClip(*vector++, abs_limit); + } +} + +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value) { + TF_LITE_ASSERT(v_size > 0); + for (int i = 0; i < v_size - 1; i++) { + vector[i] = vector[i + 1]; + } + vector[v_size - 1] = shift_value; +} + +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + const float* input_vector_ptr = input_vector; + for (int o = 0; o < output_size; o++) { + for (int r = 0; r < reduction_size; r++) { + output_vector[o] += *input_vector_ptr++; + } + } +} + +} // namespace tensor_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..c2ab78000b81485f037c507933cd024e70f39850 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -0,0 +1,189 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ + +// TDOD(ghodrat): Remove this header file and the dependency to internal data +// structure. +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace tensor_utils { + +// Limit a float input f betweeen +abs_limit and -abs_limit. +float PortableClip(float f, float abs_limit); + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector. +void PortableMatrixBatchVectorMultiplyAccumulate(const float* matrix, + int m_rows, int m_cols, + const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void PortableVectorVectorCwiseProduct(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the +// assumption here is that result array is initialized to valid values. +void PortableVectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, + int v_size, float* result); + +// Dot product of two vectors. +float PortableVectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors. +void PortableBatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector, + int v_size, + const float* batch_vector, + int n_batch, + float* result); + +// Batch vector initialization with another vector. +void PortableVectorBatchVectorAssign(const float* vector, int v_size, + int n_batch, float* batch_vector); + +// Apply sigmoid to elements of a vector. +void PortableApplySigmoidToVector(const float* vector, int v_size, + float* result); + +// Apply activation function to elements of a vector. +void PortableApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, + float* result); + +// Copy vector to another vector. +void PortableCopyVector(const float* vector, int v_size, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void PortableSub1Vector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void PortableZeroVector(float* vector, int v_size); + +// Clip elements of a vector using a abs_limit value. +void PortableClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Shift left a vector in place with v_size size. +void PortableVectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void PortableReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); + +float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); } + +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride) { + PortableMatrixBatchVectorMultiplyAccumulate(matrix, m_rows, m_cols, vector, + n_batch, result, result_stride); +} + +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result) { + PortableVectorVectorCwiseProduct(vector1, vector2, v_size, result); +} + +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result) { + PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result); +} + +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result) { + PortableVectorBatchVectorCwiseProductAccumulate(vector, v_size, batch_vector, + n_batch, result); +} + +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size) { + return PortableVectorVectorDotProduct(vector1, vector2, v_size); +} + +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride) { + PortableBatchVectorBatchVectorDotProduct(vector1, vector2, v_size, n_batch, + result, result_stride); +} + +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector) { + PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector); +} + +void ApplySigmoidToVector(const float* vector, int v_size, float* result) { + PortableApplySigmoidToVector(vector, v_size, result); +} + +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result) { + PortableApplyActivationToVector(vector, v_size, activation, result); +} + +void CopyVector(const float* vector, int v_size, float* result) { + PortableCopyVector(vector, v_size, result); +} + +void Sub1Vector(const float* vector, int v_size, float* result) { + PortableSub1Vector(vector, v_size, result); +} + +void ZeroVector(float* vector, int v_size) { + PortableZeroVector(vector, v_size); +} + +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result) { + PortableClipVector(vector, v_size, abs_limit, result); +} + +void VectorShiftLeft(float* vector, int v_size, float shift_value) { + PortableVectorShiftLeft(vector, v_size, shift_value); +} + +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size) { + PortableReductionSumVector(input_vector, output_vector, output_size, + reduction_size); +} + +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..b9ca3d5c626dff4ea8ba52949e8fea8e9b43689f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -0,0 +1,2455 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "third_party/eigen3/Eigen/Core" +#include "fixedpoint/fixedpoint.h" +#include "public/gemmlowp.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" +#include "tensorflow/contrib/lite/kernels/internal/round.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { +namespace reference_ops { + +inline int32 MultiplyByQuantizedMultiplierSmallerThanOne( + int32 x, int32 quantized_multiplier, int right_shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + return RoundingDivideByPOT( + SaturatingRoundingDoublingHighMul(x, quantized_multiplier), right_shift); +} + +inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( + int32 x, int32 quantized_multiplier, int left_shift) { + using gemmlowp::SaturatingRoundingDoublingHighMul; + return SaturatingRoundingDoublingHighMul(x * (1 << left_shift), + quantized_multiplier); +} + +template +int CountLeadingZeros(T integer_input) { + static_assert(std::is_unsigned::value, + "Only unsigned integer types handled."); + const T one_in_leading_positive = static_cast(1) + << (std::numeric_limits::digits - 1); + int leading_zeros = 0; + while (integer_input < one_in_leading_positive) { + integer_input <<= 1; + ++leading_zeros; + } + return leading_zeros; +} + +// DO NOT USE THIS STRUCT FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING ELEMENT-WISE +// BROADCASTING. +// +// NdArrayDesc describes the shape and memory layout of an N-dimensional +// rectangular array of numbers. +// +// NdArrayDesc is basically identical to Dims defined in types.h. +// However, as Dims is to be deprecated, this class exists as an adaptor +// to enable simple unoptimized implementations of element-wise broadcasting +// operations. +template +struct NdArrayDesc { + // The "extent" of each dimension. Indices along dimension d must be in the + // half-open interval [0, extents[d]). + int extents[N]; + + // The number of *elements* (not bytes) between consecutive indices of each + // dimension. + int strides[N]; +}; + +// DO NOT USE THIS FUNCTION FOR NEW FUNCTIONALITY BEYOND IMPLEMENTING +// ELEMENT-WISE BROADCASTING. +// +// Same as Offset(), except takes as NdArrayDesc instead of Dims. +inline int SubscriptToIndex(const NdArrayDesc<4>& desc, int i0, int i1, int i2, + int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < desc.extents[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < desc.extents[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < desc.extents[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < desc.extents[3]); + return i0 * desc.strides[0] + i1 * desc.strides[1] + i2 * desc.strides[2] + + i3 * desc.strides[3]; +} + +// Given the dimensions of the operands for an element-wise binary broadcast, +// adjusts them so that they can be directly iterated over with simple loops. +// Returns the adjusted dims as instances of NdArrayDesc in 'desc0_out' and +// 'desc1_out'. 'desc0_out' and 'desc1_out' cannot be nullptr. +// +// This function assumes that the two input shapes are compatible up to +// broadcasting and the shorter one has already been prepended with 1s to be the +// same length. E.g., if shape0 is (1, 16, 16, 64) and shape1 is (1, 64), +// shape1 must already have been prepended to be (1, 1, 1, 64). Recall that +// Dims refer to shapes in reverse order. In this case, input0_dims will be +// (64, 16, 16, 1) and input1_dims will be (64, 1, 1, 1). +// +// When two shapes are compatible up to broadcasting, for each dimension d, +// the input extents are either equal, or one of them is 1. +// +// This function performs the following for each dimension d: +// - If the extents are equal, then do nothing since the loop that walks over +// both of the input arrays is correct. +// - Otherwise, one (and only one) of the extents must be 1. Say extent0 is 1 +// and extent1 is e1. Then set extent0 to e1 and stride0 *to 0*. This allows +// array0 to be referenced *at any index* in dimension d and still access the +// same slice. +template +inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, + const Dims& input1_dims, + NdArrayDesc* desc0_out, + NdArrayDesc* desc1_out) { + TFLITE_DCHECK(desc0_out != nullptr); + TFLITE_DCHECK(desc1_out != nullptr); + + // Copy dims to desc. + for (int i = 0; i < N; ++i) { + desc0_out->extents[i] = input0_dims.sizes[i]; + desc0_out->strides[i] = input0_dims.strides[i]; + desc1_out->extents[i] = input1_dims.sizes[i]; + desc1_out->strides[i] = input1_dims.strides[i]; + } + + // Walk over each dimension. If the extents are equal do nothing. + // Otherwise, set the desc with extent 1 to have extent equal to the other and + // stride 0. + for (int i = 0; i < N; ++i) { + const int extent0 = ArraySize(input0_dims, i); + const int extent1 = ArraySize(input1_dims, i); + if (extent0 != extent1) { + if (extent0 == 1) { + desc0_out->strides[i] = 0; + desc0_out->extents[i] = extent1; + } else { + TFLITE_DCHECK_EQ(extent1, 1); + desc1_out->strides[i] = 0; + desc1_out->extents[i] = extent0; + } + } + } +} + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + if (bias_data) { + TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); + } + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + float filter_value = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + total += (input_value * filter_value); + } + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, pad_width, pad_height, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, + bias_dims, stride, stride, pad_width, pad_height, output_data, + output_dims, im2col_data, im2col_dims); +} + +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + (void)gemm_context; // only used in optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = + MatchingArraySize(filter_dims, 3, bias_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + int32 acc = 0; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + int32 input_val = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + int32 filter_val = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + acc += + (filter_val + filter_offset) * (input_val + input_offset); + } + } + } + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne( + acc, output_multiplier, output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + static_cast(acc); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride_width, stride_height, + pad_width, pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, output_dims, + im2col_data, im2col_dims, gemm_context); +} + +// legacy, for compatibility with old checked-in code +template +void Conv(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemm_context) { + Conv(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, stride, stride, pad_width, + pad_height, output_offset, output_multiplier, output_shift, + output_activation_min, output_activation_max, output_data, + output_dims, im2col_data, im2col_dims, gemm_context); +} + +template +inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int input_batch = ArraySize(input_dims, 3); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_batch = ArraySize(output_dims, 3); + + TFLITE_DCHECK_EQ(input_width * block_size, output_width); + TFLITE_DCHECK_EQ(input_height * block_size, output_height); + TFLITE_DCHECK_EQ(input_depth, output_depth * block_size * block_size); + TFLITE_DCHECK_EQ(input_batch, output_batch); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + const int in_d = + out_d + ((out_h % block_size) * block_size + out_w % block_size) * + output_depth; + const int in_w = out_w / block_size; + const int in_h = out_h / block_size; + const int in_b = out_b; + + const int output_index = + Offset(output_dims, out_d, out_w, out_h, out_b); + const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + + output_data[output_index] = input_data[input_index]; + } + } + } + } +} + +template +inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims, + int block_size, T* output_data, + const Dims<4>& output_dims) { + const int input_depth = ArraySize(input_dims, 0); + const int input_width = ArraySize(input_dims, 1); + const int input_height = ArraySize(input_dims, 2); + const int input_batch = ArraySize(input_dims, 3); + + const int output_depth = ArraySize(output_dims, 0); + const int output_width = ArraySize(output_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_batch = ArraySize(output_dims, 3); + + TFLITE_DCHECK_EQ(input_width, output_width * block_size); + TFLITE_DCHECK_EQ(input_height, output_height * block_size); + TFLITE_DCHECK_EQ(input_depth * block_size * block_size, output_depth); + TFLITE_DCHECK_EQ(input_batch, output_batch); + + for (int in_b = 0; in_b < input_batch; ++in_b) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + for (int in_d = 0; in_d < input_depth; ++in_d) { + const int out_d = + in_d + ((in_h % block_size) * block_size + in_w % block_size) * + input_depth; + const int out_w = in_w / block_size; + const int out_h = in_h / block_size; + const int out_b = in_b; + + const int output_index = + Offset(output_dims, out_d, out_w, out_h, out_b); + const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b); + + output_data[output_index] = input_data[input_index]; + } + } + } + } +} + +inline void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(weights_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(weights_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + float total = 0.f; + for (int d = 0; d < accum_depth; ++d) { + total += input_data[b * accum_depth + d] * + weights_data[out_c * accum_depth + d]; + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + } + output_data[out_c + output_depth * b] = ActivationFunctionWithMinMax( + total + bias_value, output_activation_min, output_activation_max); + } + } +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const float* input_data, const Dims<4>& input_dims, + const float* weights_data, const Dims<4>& weights_dims, + const float* bias_data, const Dims<4>& bias_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data, + bias_dims, output_activation_min, output_activation_max, + output_data, output_dims); +} + +inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + (void)gemm_context; // only used in optimized code. + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + // TODO(benoitjacob): This really should be: + // const int batches = ArraySize(output_dims, 1); + // but the current --variable_batch hack consists in overwriting the 3rd + // dimension with the runtime batch size, as we don't keep track for each + // array of which dimension is the batch dimension in it. + const int batches = ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3); + const int output_depth = MatchingArraySize(filter_dims, 1, output_dims, 0); + const int accum_depth = ArraySize(filter_dims, 0); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(filter_dims)); + for (int b = 0; b < batches; ++b) { + for (int out_c = 0; out_c < output_depth; ++out_c) { + int32 acc = 0; + for (int d = 0; d < accum_depth; ++d) { + int32 input_val = input_data[b * accum_depth + d]; + int32 filter_val = filter_data[out_c * accum_depth + d]; + acc += (filter_val + filter_offset) * (input_val + input_offset); + } + if (bias_data) { + acc += bias_data[Offset(bias_dims, out_c, 0, 0, 0)]; + } + acc = MultiplyByQuantizedMultiplierSmallerThanOne(acc, output_multiplier, + output_shift); + acc += output_offset; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[out_c + output_depth * b] = static_cast(acc); + } + } +} + +// legacy, for compatibility with old checked-in code +template +void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, + int32 input_offset, const uint8* filter_data, + const Dims<4>& filter_dims, int32 filter_offset, + const int32* bias_data, const Dims<4>& bias_dims, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemm_context) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims, + filter_offset, bias_data, bias_dims, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims, gemm_context); +} + +template +void NonGlobalBatchNormalization( + const float* input_data, const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, const float* multiplier_data, + const Dims<4>& multiplier_dims, const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, mean_dims, 2, multiplier_dims, 2, + offset_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, mean_dims, 1, multiplier_dims, 1, + offset_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, x, y, 0)]) * + multiplier_data[Offset(multiplier_dims, c, x, y, 0)] + + offset_data[Offset(offset_dims, c, x, y, 0)]); + } + } + } + } +} + +template +void GlobalBatchNormalization(const float* input_data, + const Dims<4>& input_dims, const float* mean_data, + const Dims<4>& mean_dims, + const float* multiplier_data, + const Dims<4>& multiplier_dims, + const float* offset_data, + const Dims<4>& offset_dims, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0, + offset_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + (input_data[Offset(input_dims, c, x, y, b)] - + mean_data[Offset(mean_dims, c, 0, 0, 0)]) * + multiplier_data[Offset(multiplier_dims, c, 0, 0, 0)] + + offset_data[Offset(offset_dims, c, 0, 0, 0)]); + } + } + } + } +} + +inline void Relu(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float lower = 0; + float clamped = val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu1(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 1; + const float lower = -1; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +inline void Relu6(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + const float upper = 6; + const float lower = 0; + float clamped = val > upper ? upper : val < lower ? lower : val; + output_data[Offset(output_dims, c, x, y, b)] = clamped; + } + } + } + } +} + +template +void L2Normalization(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone, ""); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + float squared_l2_norm = 0; + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + squared_l2_norm += val * val; + } + float l2_norm = std::sqrt(squared_l2_norm); + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] / l2_norm; + } + } + } + } +} + +inline void GetInvSqrtQuantizedMultiplier(int32 input, int32* output_inv_sqrt, + int* output_shift) { + *output_shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*output_shift; + } + TFLITE_DCHECK_GT(input, 0); + const unsigned max_left_shift_bits = __builtin_clz(input) - 1; + const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2; + const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1; + *output_shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + TFLITE_DCHECK_GE(input, (1 << 27)); + TFLITE_DCHECK_LT(input, (1 << 29)); + using gemmlowp::FixedPoint; + using gemmlowp::Rescale; + using gemmlowp::SaturatingRoundingMultiplyByPOT; + // Using 3 integer bits gives us enough room for the internal arithmetic in + // this Newton-Raphson iteration. + using F3 = FixedPoint; + using F0 = FixedPoint; + const F3 fixedpoint_input = F3::FromRaw(input >> 1); + const F3 fixedpoint_half_input = + SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input); + const F3 fixedpoint_half_three = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5); + // Newton-Raphson iteration + // Naive unoptimized starting guess: x = 1 + F3 x = F3::One(); + // Naive unoptimized number of iterations: 5 + for (int i = 0; i < 5; i++) { + const F3 x3 = Rescale<3>(x * x * x); + x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3); + } + const F0 fixedpoint_half_sqrt_2 = + GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.); + x = x * fixedpoint_half_sqrt_2; + *output_inv_sqrt = x.raw(); + if (*output_shift < 0) { + *output_inv_sqrt <<= -*output_shift; + *output_shift = 0; + } +} + +inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, uint8* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + TFLITE_DCHECK_EQ(batches, 1); + TFLITE_DCHECK_EQ(height, 1); + TFLITE_DCHECK_EQ(width, 1); + int32 square_l2_norm = 0; + for (int i = 0; i < depth; i++) { + int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; + square_l2_norm += diff * diff; + } + int32 inv_l2norm_multiplier; + int inv_l2norm_shift; + GetInvSqrtQuantizedMultiplier(square_l2_norm, &inv_l2norm_multiplier, + &inv_l2norm_shift); + + for (int i = 0; i < depth; i++) { + int32 diff = input_data[Offset(input_dims, i, 0, 0, 0)] - input_zero_point; + int32 rescaled_diff = MultiplyByQuantizedMultiplierSmallerThanOne( + 128 * diff, inv_l2norm_multiplier, inv_l2norm_shift); + int32 unclamped_output_val = 128 + rescaled_diff; + int32 output_val = std::min(255, std::max(0, unclamped_output_val)); + output_data[Offset(output_dims, i, 0, 0, 0)] = + static_cast(output_val); + } +} + +inline void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] + + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Add(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Add(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +template +inline void Add(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, int input2_shift, + int32 output_offset, int32 output_multiplier, int output_shift, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const int32 input1_val = + input1_offset + input1_data[Offset(input1_dims, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[Offset(input2_dims, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// TODO(jiawen): We can implement BroadcastAdd on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastAdd/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 shifted_input1_val = input1_val * (1 << left_shift); + const int32 shifted_input2_val = input2_val * (1 << left_shift); + const int32 scaled_input1_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input1_val, input1_multiplier, input1_shift); + const int32 scaled_input2_val = + MultiplyByQuantizedMultiplierSmallerThanOne( + shifted_input2_val, input2_multiplier, input2_shift); + const int32 raw_sum = scaled_input1_val + scaled_input2_val; + const int32 raw_output = + MultiplyByQuantizedMultiplierSmallerThanOne( + raw_sum, output_multiplier, output_shift) + + output_offset; + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, raw_output)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +template +inline void BroadcastAdd(int left_shift, const uint8* input1_data, + const Dims<4>& input1_dims, int32 input1_offset, + int32 input1_multiplier, int input1_shift, + const uint8* input2_data, const Dims<4>& input2_dims, + int32 input2_offset, int32 input2_multiplier, + int input2_shift, int32 output_offset, + int32 output_multiplier, int output_shift, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + BroadcastAdd(left_shift, input1_data, input1_dims, input1_offset, + input1_multiplier, input1_shift, input2_data, input2_dims, + input2_offset, input2_multiplier, input2_shift, output_offset, + output_multiplier, output_shift, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] * + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void Mul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary +// dimensionality if the runtime code does a single loop over one dimension +// that handles broadcasting as the base case. The code generator would then +// generate max(D1, D2) nested for loops. +template +void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit"); + + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest + // stride, typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for + // the best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + const int32 input1_val = + input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)]; + const int32 input2_val = + input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + const int32 unclamped_result = + output_offset + + MultiplyByQuantizedMultiplierSmallerThanOne( + input1_val * input2_val, output_multiplier, output_shift); + const int32 clamped_output = + std::min(output_activation_max, + std::max(output_activation_min, unclamped_result)); + output_data[Offset(output_dims, c, x, y, b)] = + static_cast(clamped_output); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, + int32 input1_offset, const uint8* input2_data, + const Dims<4>& input2_dims, int32 input2_offset, + int32 output_offset, int32 output_multiplier, + int output_shift, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, + input2_dims, input2_offset, output_offset, output_multiplier, + output_shift, output_activation_min, output_activation_max, + output_data, output_dims); +} + +template +void Concatenation(int concat_dim, const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK_GT(inputs_count, 1); + int concat_size = 0; + for (int i = 0; i < inputs_count; i++) { + for (int j = 0; j < 4; j++) { + if (j != concat_dim) { + MatchingArraySize(*input_dims[i], j, output_dims, j); + } + } + concat_size += ArraySize(*input_dims[i], concat_dim); + } + TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim)); + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + int outer_size = 1; + for (int i = concat_dim + 1; i < 4; i++) { + outer_size *= output_dims.sizes[i]; + } + Scalar* output_ptr = output_data; + for (int k = 0; k < outer_size; k++) { + for (int i = 0; i < inputs_count; ++i) { + const int copy_size = + input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim]; + memcpy(output_ptr, input_data[i] + k * copy_size, + copy_size * sizeof(Scalar)); + output_ptr += copy_size; + } + } +} + +template +void DepthConcatenation(const Scalar* const* input_data, + const Dims<4>* const* input_dims, int inputs_count, + Scalar* output_data, const Dims<4>& output_dims) { + Concatenation(0, input_data, input_dims, inputs_count, + output_data, output_dims); +} + +inline void LstmCell(const float* input_data, const Dims<4>& input_dims, + const float* prev_activ_data, + const Dims<4>& prev_activ_dims, const float* weights_data, + const Dims<4>& weights_dims, const float* bias_data, + const Dims<4>& bias_dims, const float* prev_state_data, + const Dims<4>& prev_state_dims, float* output_state_data, + const Dims<4>& output_state_dims, float* output_activ_data, + const Dims<4>& output_activ_dims, float* concat_temp_data, + const Dims<4>& concat_temp_dims, float* activ_temp_data, + const Dims<4>& activ_temp_dims) { + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + + // Concatenate prev_activ and input data together + std::vector concat_input_arrays_data; + std::vector const*> concat_input_arrays_dims; + concat_input_arrays_data.push_back(input_data); + concat_input_arrays_data.push_back(prev_activ_data); + concat_input_arrays_dims.push_back(&input_dims); + concat_input_arrays_dims.push_back(&prev_activ_dims); + Concatenation( + 0, &(concat_input_arrays_data[0]), &(concat_input_arrays_dims[0]), + concat_input_arrays_data.size(), concat_temp_data, concat_temp_dims); + + // Fully connected + FullyConnected( + concat_temp_data, concat_temp_dims, weights_data, weights_dims, bias_data, + bias_dims, activ_temp_data, activ_temp_dims); + + // Memory state update (the LSTM "guts") + for (int b = 0; b < batches; ++b) { + for (int w = 0; w < width; ++w) { + for (int h = 0; h < height; ++h) { + for (int c = 0; c < output_depth; ++c) { + const float input_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 0 * output_depth + c, w, h, b)])); + const float new_input = std::tanh(activ_temp_data[Offset( + activ_temp_dims, 1 * output_depth + c, w, h, b)]); + const float forget_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 2 * output_depth + c, w, h, b)])); + const float output_gate = + 1.f / + (1.f + std::exp(-activ_temp_data[Offset( + activ_temp_dims, 3 * output_depth + c, w, h, b)])); + const float new_state = + input_gate * new_input + + forget_gate * + prev_state_data[Offset(prev_state_dims, c, w, h, b)]; + output_state_data[Offset(output_state_dims, c, w, h, b)] = new_state; + output_activ_data[Offset(output_activ_dims, c, w, h, b)] = + output_gate * std::tanh(new_state); + } + } + } + } +} + +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + TFLITE_DCHECK_GE(outputs_count, 1); + for (int i = 0; i < outputs_count; i++) { + /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3); + /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); + /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); + } + const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); + const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); + const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); + // for now we dont have a model with a TensorFlowSplit + // with fused activation function. + TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + int in_c = 0; + for (int i = 0; i < outputs_count; ++i) { + const int depth = ArraySize(*output_dims[i], 0); + for (int c = 0; c < depth; ++c) { + output_data[i][Offset(*output_dims[i], c, x, y, b)] = + input_data[Offset(input_dims, in_c, x, y, b)]; + in_c++; + } + } + TFLITE_DCHECK(in_c == ArraySize(input_dims, 0)); + } + } + } +} + +// TODO(benoitjacob) make this a proper reference impl without Eigen! +template +using MatrixMap = typename std::conditional< + std::is_const::value, + Eigen::Map::type, + Eigen::Dynamic, Eigen::Dynamic>>, + Eigen::Map>>::type; + +template +MatrixMap MapAsMatrixWithFirstDimAsRows(Scalar* data, + const Dims& dims) { + const int rows = dims.sizes[0]; + int cols = 1; + for (int d = 1; d < N; d++) { + cols *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +template +MatrixMap MapAsMatrixWithLastDimAsCols(Scalar* data, + const Dims& dims) { + const int cols = dims.sizes[N - 1]; + int rows = 1; + for (int d = 0; d < N - 1; d++) { + rows *= dims.sizes[d]; + } + return MatrixMap(data, rows, cols); +} + +inline int NodeOffset(int b, int h, int w, int height, int width) { + return (b * height + h) * width + w; +} + +inline void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float total = 0.f; + float filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + total += + input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + filter_count++; + } + } + const float average = total / filter_count; + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(average, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, float* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + int32 acc = 0; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + acc += input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + filter_count++; + } + } + acc = (acc + filter_count / 2) / filter_count; + acc = std::max(acc, output_activation_min); + acc = std::min(acc, output_activation_max); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + static_cast(acc); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + AveragePool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, + int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float sum_squares = 0.f; + int filter_count = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + const float val = + input_data[Offset(input_dims, channel, in_x, in_y, batch)]; + sum_squares += val * val; + filter_count++; + } + } + const float l2pool_result = std::sqrt(sum_squares / filter_count); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(l2pool_result, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + L2Pool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + L2Pool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + float max = std::numeric_limits::lowest(); + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + } + } + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(max, output_activation_min, + output_activation_max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, float* output_data, + const Dims<4>& output_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + float* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_data, output_dims); +} + +inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + TFLITE_DCHECK_LE(output_activation_min, output_activation_max); + TFLITE_DCHECK_GE(output_activation_min, 0); + TFLITE_DCHECK_LE(output_activation_max, 255); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int channel = 0; channel < depth; ++channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + // Compute the boundaries of the filter region clamped so as to + // ensure that the filter window fits in the input array. + const int filter_x_start = std::max(0, -in_x_origin); + const int filter_x_end = + std::min(filter_width, input_width - in_x_origin); + const int filter_y_start = std::max(0, -in_y_origin); + const int filter_y_end = + std::min(filter_height, input_height - in_y_origin); + uint8 max = 0; + for (int filter_y = filter_y_start; filter_y < filter_y_end; + ++filter_y) { + for (int filter_x = filter_x_start; filter_x < filter_x_end; + ++filter_x) { + const int in_x = in_x_origin + filter_x; + const int in_y = in_y_origin + filter_y; + max = std::max( + max, + input_data[Offset(input_dims, channel, in_x, in_y, batch)]); + } + } + max = std::max(max, output_activation_min); + max = std::min(max, output_activation_max); + output_data[Offset(output_dims, channel, out_x, out_y, batch)] = + static_cast(max); + } + } + } + } +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, + int stride_width, int stride_height, int pad_width, int pad_height, + int filter_width, int filter_height, int32 output_activation_min, + int32 output_activation_max, uint8* output_data, + const Dims<4>& output_dims) { + static_assert(Ac == FusedActivationFunctionType::kNone || + Ac == FusedActivationFunctionType::kRelu || + Ac == FusedActivationFunctionType::kRelu6 || + Ac == FusedActivationFunctionType::kRelu1, + ""); + if (Ac == FusedActivationFunctionType::kNone) { + TFLITE_DCHECK_EQ(output_activation_min, 0); + TFLITE_DCHECK_EQ(output_activation_max, 255); + } + MaxPool(input_data, input_dims, stride_width, stride_height, pad_width, + pad_height, filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +// legacy, for compatibility with old checked-in code +template +void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, + int pad_width, int pad_height, int filter_width, int filter_height, + int32 output_activation_min, int32 output_activation_max, + uint8* output_data, const Dims<4>& output_dims) { + MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, + filter_width, filter_height, output_activation_min, + output_activation_max, output_data, output_dims); +} + +inline void LocalResponseNormalization(const float* input_data, + const Dims<4>& input_dims, int range, + float bias, float alpha, float beta, + float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const int begin_input_c = std::max(0, c - range); + const int end_input_c = std::min(depth, c + range); + float accum = 0.f; + for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) { + const float input_val = + input_data[Offset(input_dims, input_c, x, y, b)]; + accum += input_val * input_val; + } + const float multiplier = std::pow(bias + alpha * accum, -beta); + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] * multiplier; + } + } + } + } +} + +inline void Softmax(const float* input_data, const Dims<4>& input_dims, + float beta, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + // Find max element value which we'll use to ensure numerical stability + // taking advantage of the following equality: + // exp(x[i])/sum(exp(x[i])) == exp(x[i]+C)/sum(exp(x[i]+C)) + float max = std::numeric_limits::lowest(); + for (int c = 0; c < depth; ++c) { + max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]); + } + + // Compute sum. + float sum = 0.f; + for (int c = 0; c < depth; ++c) { + sum += std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) * + beta); + } + + // Compute result. + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::exp((input_data[Offset(input_dims, c, x, y, b)] - max) * + beta) / + sum; + } + } + } + } +} + +inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, + int32 input_beta_multiplier, int32 input_beta_left_shift, + int diff_min, uint8* output_data, + const Dims<4>& output_dims) { + // The representation chosen for the input to the exp() function is Q5.26. + // We need to leave extra space since values that we skip might be as large as + // -32 before multiplying by input_beta_multiplier, and therefore as large as + // -16 afterwards. Note that exp(-8) is definitely not insignificant to + // accumulation, but exp(-16) definitely is. + static const int kScaledDiffIntegerBits = 5; + static const int kAccumulationIntegerBits = 12; + using FixedPointScaledDiff = + gemmlowp::FixedPoint; + using FixedPointAccum = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int x = 0; x < width; ++x) { + for (int y = 0; y < height; ++y) { + uint8 max_in_row = 0; + for (int c = 0; c < depth; ++c) { + max_in_row = + std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + } + + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + + int32 fixed_sum_of_exps = sum_of_exps.raw(); + int headroom_plus_one = + CountLeadingZeros(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. + int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + for (int c = 0; c < depth; ++c) { + int32 input_diff = + static_cast(input_data[Offset(input_dims, c, x, y, b)]) - + max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data[Offset(output_dims, c, x, y, b)] = static_cast( + std::max(std::min(unsat_output, static_cast(255)), 0)); + + } else { + output_data[Offset(output_dims, c, x, y, b)] = 0; + } + } + } + } + } +} + +inline void Logistic(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + float result = 1.f / (1.f + std::exp(-val)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered <= -input_range_radius) { + output_val = 0; + } else if (input_val_centered >= input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = + FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[Offset(output_dims, c, x, y, b)] = output_val; + } + } + } + } +} + +inline void Tanh(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + float val = input_data[Offset(input_dims, c, x, y, b)]; + float result = std::tanh(val); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, + int32 zero_point, double scale, float* output_data, + const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int32 val = input_data[Offset(input_dims, c, x, y, b)]; + float result = static_cast(scale * (val - zero_point)); + output_data[Offset(output_dims, c, x, y, b)] = result; + } + } + } + } +} + +inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, + float rmin, float rmax, float* output_data, + const Dims<4>& output_dims) { + // 0 should always be a representable value. Let's assume that the initial + // min,max range contains 0. + TFLITE_DCHECK_LE(rmin, 0.); + TFLITE_DCHECK_GE(rmax, 0.); + + // Determine quantization parameters: zero_point, scale. + using Integer = uint8; + const Integer qmin = std::numeric_limits::min(); + const Integer qmax = std::numeric_limits::max(); + const float qmin_float = qmin; + const float qmax_float = qmax; + int32 zero_point = 0; + float scale = 0.f; + // If rmin==rmax, both must be zero per the above assertion, + // so we are done. + if (rmin != rmax) { + // First determine the scale. + scale = (rmax - rmin) / (qmax_float - qmin_float); + + // Zero-point computation. + // First the initial floating-point computation. The zero-point can be + // determined from solving an affine equation for any known pair + // (real value, corresponding quantized value). + // We know two such pairs: (rmin, qmin) and (rmax, qmax). + // The arithmetic error on the zero point computed from either pair + // will be roughly machine_epsilon * (sum of absolute values of terms) + // so we want to use the variant that adds the smaller terms. + const float zero_point_from_min = qmin_float - rmin / scale; + const float zero_point_from_max = qmax_float - rmax / scale; + const float zero_point_from_min_error = + std::abs(qmin_float) + std::abs(rmin / scale); + const float zero_point_from_max_error = + std::abs(qmax_float) + std::abs(rmax / scale); + + const float zero_point_float = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + // Now we need to nudge the zero point to be an integer + // (our zero points are integer, and this is motivated by the requirement + // to be able to represent the real value "0" exactly as a quantized value, + // which is required in multiple places, for example in Im2col with SAME + // padding). + if (zero_point_float < qmin_float) { + zero_point = qmin; + } else if (zero_point_float > qmax_float) { + zero_point = qmax; + } else { + zero_point = static_cast(TfLiteRound(zero_point_float)); + } + // The zero point should always be in the range of quantized value, + // [qmin, qmax]. + TFLITE_DCHECK_GE(zero_point, qmin); + TFLITE_DCHECK_LE(zero_point, qmax); + } + + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const float src_val = input_data[Offset(input_dims, c, x, y, b)]; + const float unclamped_quantized_val = + TfLiteRound(zero_point + src_val / scale); + const float quantized_val = std::min( + qmax_float, std::max(qmin_float, unclamped_quantized_val)); + const float dst_val = scale * (quantized_val - zero_point); + output_data[Offset(output_dims, c, x, y, b)] = dst_val; + } + } + } + } +} + +template +inline void Cast(const SrcT* input_data, const Dims<4>& input_dims, + DstT* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int offset = Offset(input_dims, c, x, y, b); + output_data[offset] = static_cast(input_data[offset]); + } + } + } + } +} + +inline void Floor(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + int offset = Offset(input_dims, c, x, y, b); + output_data[offset] = std::floor(input_data[offset]); + } + } + } + } +} + +template +inline void Gather(const T* input_data, const Dims<4>& input_dims, + int input_rank, const int32* coords_data, + const Dims<4>& coords_dims, T* output_data, + const Dims<4>& output_dims) { + TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]); + int stride = input_dims.strides[input_rank - 1]; + T* out = output_data; + + for (int i = 0; i < coords_dims.sizes[0]; i++) { + TFLITE_DCHECK_GE(coords_data[i], 0); + TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]); + const T* in = input_data + coords_data[i] * stride; + memcpy(out, in, sizeof(T) * stride); + out += stride; + } +} + +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); + int32 input_height = ArraySize(input_dims, 2); + int32 input_width = ArraySize(input_dims, 1); + int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1); + TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2); + int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)]; + int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; + float height_scale = static_cast(input_height) / output_height; + float width_scale = static_cast(input_width) / output_width; + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < output_height; ++y) { + float input_y = y * height_scale; + int32 y0 = static_cast(std::floor(input_y)); + int32 y1 = std::min(y0 + 1, input_height - 1); + for (int x = 0; x < output_width; ++x) { + float input_x = x * width_scale; + int32 x0 = static_cast(std::floor(input_x)); + int32 x1 = std::min(x0 + 1, input_width - 1); + for (int c = 0; c < depth; ++c) { + float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] * + (1 - (input_y - y0)) * + (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x0, y1, b)] * + (input_y - y0) * (1 - (input_x - x0)) + + input_data[Offset(input_dims, c, x1, y0, b)] * + (1 - (input_y - y0)) * (input_x - x0) + + input_data[Offset(input_dims, c, x1, y1, b)] * + (input_y - y0) * (input_x - x0); + output_data[Offset(output_dims, c, x, y, b)] = interpolation; + } + } + } + } +} + +template +inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, + const int32* paddings_data, + const Dims<4>& paddings_dims, T* output_data, + const Dims<4>& output_dims) { + const int output_batch_size = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_height = block_shape_data[0]; + const int block_shape_width = block_shape_data[1]; + const int padding_top = paddings_data[0]; + const int padding_left = paddings_data[2]; + + for (int out_b = 0; out_b < output_batch_size; ++out_b) { + int input_batch = out_b % input_batch_size; + int shift_w = (out_b / input_batch_size) % block_shape_width; + int shift_h = (out_b / input_batch_size) / block_shape_width; + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); + if (out_h * block_shape_height < padding_top || + out_h * block_shape_height >= padding_top + input_height || + out_w * block_shape_width < padding_left || + out_w * block_shape_width >= padding_left + input_width) { + memset(out, 0, depth * sizeof(T)); + } else { + const T* in = + input_data + + Offset(input_dims, 0, + (out_w * block_shape_width + shift_w) - padding_left, + (out_h * block_shape_height + shift_h) - padding_top, + input_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } + } +} + +template +inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, + const int32* block_shape_data, + const Dims<4>& block_shape_dims, T* output_data, + const Dims<4>& output_dims) { + const int output_batch_size = ArraySize(output_dims, 3); + const int input_batch_size = ArraySize(input_dims, 3); + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int depth = ArraySize(input_dims, 0); + const int block_shape_width = block_shape_data[1]; + const int block_shape_height = block_shape_data[0]; + + for (int in_batch = 0; in_batch < input_batch_size; ++in_batch) { + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + int out_batch = in_batch % output_batch_size; + int out_w = in_w * block_shape_width + + (in_batch / output_batch_size) % block_shape_width; + int out_h = in_h * block_shape_height + + (in_batch / output_batch_size) / block_shape_width; + T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch); + const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch); + memcpy(out, in, depth * sizeof(T)); + } + } + } +} + +template +inline void Pad(const T* input_data, const Dims<4>& input_dims, + const std::vector& left_paddings, + const std::vector& right_paddings, T* output_data, + const Dims<4>& output_dims) { + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int left_b_padding = left_paddings[3]; + const int left_h_padding = left_paddings[2]; + const int left_w_padding = left_paddings[1]; + const int left_d_padding = left_paddings[0]; + + const int right_b_padding = right_paddings[3]; + const int right_h_padding = right_paddings[2]; + const int right_w_padding = right_paddings[1]; + const int right_d_padding = right_paddings[0]; + + const T* in_ptr = input_data; + T* out_ptr = output_data; + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_h = 0; out_h < output_height; ++out_h) { + for (int out_w = 0; out_w < output_width; ++out_w) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + if (out_b < left_b_padding || + out_b >= output_batch - right_b_padding || + out_h < left_h_padding || + out_h >= output_height - right_h_padding || + out_w < left_w_padding || + out_w >= output_width - right_w_padding || + out_d < left_d_padding || + out_d >= output_depth - right_d_padding) { + *out_ptr++ = 0; + } else { + *out_ptr++ = *in_ptr++; + } + } + } + } + } +} + +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + const int start_b = (begin_mask & 8) ? 0 : starts[3]; + const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; + const int start_h = (begin_mask & 4) ? 0 : starts[2]; + const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; + const int start_w = (begin_mask & 2) ? 0 : starts[1]; + const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; + const int start_d = (begin_mask & 1) ? 0 : starts[0]; + const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { + for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { + for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { + for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } +} + +template +inline void Slice(const T* input_data, const Dims<4>& input_dims, + const std::vector& begin, const std::vector& size, + T* output_data, const Dims<4>& output_dims) { + // TODO(dkalenichenko): This op only supports 4D tensors. + TFLITE_DCHECK_EQ(begin.size(), 4); + TFLITE_DCHECK_EQ(size.size(), 4); + const int start_b = begin[3]; + const int stop_b = + size[3] == -1 ? input_dims.sizes[3] - start_b : start_b + size[3]; + const int start_h = begin[2]; + const int stop_h = + size[2] == -1 ? input_dims.sizes[2] - start_b : start_b + size[2]; + const int start_w = begin[1]; + const int stop_w = + size[1] == -1 ? input_dims.sizes[1] - start_b : start_b + size[1]; + const int start_d = begin[0]; + const int stop_d = + size[0] == -1 ? input_dims.sizes[0] - start_d : start_d + size[0]; + + T* out_ptr = output_data; + for (int in_b = start_b; in_b < stop_b; ++in_b) { + for (int in_h = start_h; in_h < stop_h; ++in_h) { + for (int in_w = start_w; in_w < stop_w; ++in_w) { + for (int in_d = start_d; in_d < stop_d; ++in_d) { + *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; + } + } + } + } +} + +template +inline void Mean(const T* input_data, const Dims<4>& input_dims, + const std::vector& reduction_indices, T* output_data, + const Dims<4>& output_dims) { + const int output_batch = ArraySize(output_dims, 3); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + const int output_depth = ArraySize(output_dims, 0); + + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + + // The current implementation only supports simultaneous reduction over + // width and height. + TFLITE_DCHECK_EQ(reduction_indices.size(), 2); + TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) || + (reduction_indices[0] == 2 && reduction_indices[1] == 1)); + TFLITE_DCHECK_EQ(output_height, 1); + TFLITE_DCHECK_EQ(output_width, 1); + + for (int out_b = 0; out_b < output_batch; ++out_b) { + for (int out_d = 0; out_d < output_depth; ++out_d) { + float value = 0; + for (int in_h = 0; in_h < input_height; ++in_h) { + for (int in_w = 0; in_w < input_width; ++in_w) { + value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)]; + } + } + output_data[Offset(output_dims, out_d, 0, 0, out_b)] = + value / (input_width * input_height); + } + } +} + +template +void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, + const Dims<4>& input2_dims, T* output_data, + const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + + // In Tensorflow, the dimensions are canonically named (batch_number, row, + // col, channel), with extents (batches, height, width, depth), with the + // trailing dimension changing most rapidly (channels has the smallest stride, + // typically 1 element). + // + // In generated C code, we store arrays with the dimensions reversed. The + // first dimension has smallest stride. + // + // We name our variables by their Tensorflow convention, but generate C code + // nesting loops such that the innermost loop has the smallest stride for the + // best cache behavior. + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input1_data[SubscriptToIndex(desc1, c, x, y, b)] - + input2_data[SubscriptToIndex(desc2, c, x, y, b)]; + } + } + } + } +} + +template +void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + int batches = MatchingArraySize(input1_dims, 3, output_dims, 3); + int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2); + int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1); + int depth = MatchingArraySize(input1_dims, 0, output_dims, 0); + + auto min_value = input2_data[0]; + + for (int b = 0; b < batches; b++) { + for (int y = 0; y < input_height; y++) { + for (int x = 0; x < input_width; x++) { + for (int c = 0; c < depth; c++) { + int offset = Offset(input1_dims, c, x, y, b); + output_data[offset] = + input1_data[offset] > min_value ? min_value : input1_data[offset]; + } + } + } + } +} + +template +void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, T* output_data, + const Dims<4>& output_dims) { + int batches = MatchingArraySize(input1_dims, 3, output_dims, 3); + int input_height = MatchingArraySize(input1_dims, 2, output_dims, 2); + int input_width = MatchingArraySize(input1_dims, 1, output_dims, 1); + int depth = MatchingArraySize(input1_dims, 0, output_dims, 0); + + auto max_value = input2_data[0]; + + for (int b = 0; b < batches; b++) { + for (int y = 0; y < input_height; y++) { + for (int x = 0; x < input_width; x++) { + for (int c = 0; c < depth; c++) { + int offset = Offset(input1_dims, c, x, y, b); + output_data[offset] = + input1_data[offset] < max_value ? max_value : input1_data[offset]; + } + } + } + } +} + +} // namespace reference_ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h new file mode 100644 index 0000000000000000000000000000000000000000..38525b0e208b852343849096ac68cbfc9ef3e389 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/round.h @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ + +#include + +namespace tflite { + +// TODO(aselle): See if we can do this only on jdk. Also mikecase, check +// if you need this for java host build. +#if defined(__ANDROID__) && !defined(__NDK_MAJOR__) +template +inline float TfLiteRound(const float x) { + return ::round(x); +} +inline double TfLiteRound(const double x) { return ::round(x); } +#else +template +inline T TfLiteRound(const T x) { + return std::round(x); +} +#endif + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..ee4111e0416560d94d513c528971bdf3bf819662 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/types.h" + +namespace tflite { + +template +inline T* GetTensorData(TfLiteTensor* tensor); + +template <> +inline float* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.f : nullptr; +} + +template <> +inline uint8_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.uint8 : nullptr; +} + +template <> +inline int32_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? tensor->data.i32 : nullptr; +} + +template <> +inline int64_t* GetTensorData(TfLiteTensor* tensor) { + return tensor != nullptr ? reinterpret_cast(tensor->data.raw) + : nullptr; +} + +inline int RemapDim(int max_dimensions, int d) { + return max_dimensions - d - 1; +} + +// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object +// even if the original tensors were not 4D. We should consider rewriting them +// to take a more generic 'shape' object. +inline Dims<4> GetTensorDims(const int data[], const int size) { + Dims<4> d; + for (int i = 0; i < 4; ++i) { + int src = size - i - 1; + if (src >= 0) { + d.sizes[i] = data[src]; + } else { + d.sizes[i] = 1; + } + } + d.strides[0] = 1; + for (int i = 1; i < 4; i++) { + d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; + } + return d; +} + +inline Dims<4> GetTensorDims(std::vector data) { + return GetTensorDims(data.data(), data.size()); +} + +inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { + if (tensor == nullptr) { + return Dims<4>(); + } + + auto* dims = tensor->dims; + return GetTensorDims(dims->data, dims->size); +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf2068d320f65cf0195abbc181f4ef4ff8f20679 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc @@ -0,0 +1,55 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include +#include + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +TEST(TensorTest, GetTensorDims4D) { + Dims<4> d = GetTensorDims({2, 3, 4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +} + +TEST(TensorTest, GetTensorDims3D) { + Dims<4> d = GetTensorDims({3, 4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +} + +TEST(TensorTest, GetTensorDims2D) { + Dims<4> d = GetTensorDims({4, 5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20)); +} + +TEST(TensorTest, GetTensorDims1D) { + Dims<4> d = GetTensorDims({5}); + EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1)); + EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..904a97803a6a9ba369c1e64c711b12d19ffc10c4 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +#ifndef USE_NEON +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#define USE_NEON +#endif // defined(__ARM_NEON__) || defined(__ARM_NEON) +#endif // USE_NEON + +#ifdef USE_NEON +#include "tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h" +#else +#include "tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h" +#endif // USE_NEON diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..0e69ef5982f01e364d865684652d1dfecab6fee3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace tensor_utils { + +// Limit a float input f betweeen +abs_limit and -abs_limit. +float Clip(float f, float abs_limit); + +// Multiply a matrix by a batch vector, and store results in a batch-size +// vector using a stride value provided in result_stride. 'result_stride' shows +// how the number of elements between consecutive result values. For example +// result_stride = 1, will cause the output to look like this: +// [O_1, 0_2, ... O_rows] in memory, but result_stride = 3, will cause it to be +// arranged like this in memory: [O_1, x, x, 0_2, x, x, ..., O_rows] +void MatrixBatchVectorMultiplyAccumulate(const float* matrix, int m_rows, + int m_cols, const float* vector, + int n_batch, float* result, + int result_stride); + +// Cwise product of two vectors. +void VectorVectorCwiseProduct(const float* vector1, const float* vector2, + int v_size, float* result); + +// Cwise product and accumulate of two vectors. Since it's a MAC opertation, the +// assumption here is that result array is initialized to valid values. +void VectorVectorCwiseProductAccumulate(const float* vector1, + const float* vector2, int v_size, + float* result); + +// Dot product of two vectors. +float VectorVectorDotProduct(const float* vector1, const float* vector2, + int v_size); + +// Dot product of two batch vectors of size n_batch * v_size: +// vector1 = [x_1_1, x_1_2, ..., x_1_vsize, +// x_2_1, x_2_2, ..., x_2_vsize, +// ... +// x_nbatch_1,..., x_nbatch_vsize] +// vector2 = [y_1_1, y_1_2, ..., y_1_vsize, +// y_2_1, y_2_2, ..., y_2_vsize, +// ... +// y_nbatch_1,..., y_nbatch_vsize] +// Then result will be a vector of n_batch size which will be saved with a +// stride of result_stride in memory starting from 'result': +// [x_1_1 * y_1_1 + x_1_2 * y_1_2 + ... + x_1_vsize * y_1_vsize, +// x_2_1 * y_2_1 + x_2_2 * y_2_2 + ... + x_2_vsize * y_2_vsize, +// ... +// x_nbatch_1 * y_nbatch_1 + ... + x_nbatch_vsize * y_nbatch_vsize] +void BatchVectorBatchVectorDotProduct(const float* vector1, + const float* vector2, int v_size, + int n_batch, float* result, + int result_stride); + +// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC +// operation, the assumption here is that result array is initialized to valid +// values. +void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size, + const float* batch_vector, + int n_batch, float* result); + +// Batch vector initialization with another vector. +void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch, + float* batch_vector); + +// Apply sigmoid to elements of a vector. +void ApplySigmoidToVector(const float* vector, int v_size, float* result); + +// Apply activation function to elements of a vector. +void ApplyActivationToVector(const float* vector, int v_size, + TfLiteFusedActivation activation, float* result); + +// Copy vector to another vector. +void CopyVector(const float* vector, int v_size, float* result); + +// Compute "1.0f - elements of vector" (used in CIFG). +void Sub1Vector(const float* vector, int v_size, float* result); + +// Fill vector with 0.f. +void ZeroVector(float* vector, int v_size); + +// Clip elements of a vector using a abs_limit value. +void ClipVector(const float* vector, int v_size, float abs_limit, + float* result); + +// Shift left a vector in place with v_size size. +void VectorShiftLeft(float* vector, int v_size, float shift_value); + +// Reduce-sum on a float input vector: +// input_vector: float pointer to input vector. +// output_vector: float pointer to vector. +// output_size: output vector size. +// reduction_size: number of consecutive elements from input vector which are +// added to get one element of output. +void ReductionSumVector(const float* input_vector, float* output_vector, + int output_size, int reduction_size); +} // namespace tensor_utils +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..588f1a428b8c84367d659c2c5bb59a411cd8bb34 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc @@ -0,0 +1,192 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" + +namespace tflite { +namespace tensor_utils { + +TEST(uKernels, ClipTest) { + constexpr int kVectorSize = 10; + constexpr float kAbsLimit = 2.0; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + std::vector output(kVectorSize); + ClipVector(input, kVectorSize, kAbsLimit, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {0.0, -0.5, 1.0, -1.5, 2.0, -2.0, 2.0, -2.0, 2.0, -2.0}))); +} + +TEST(uKernels, MatrixBatchVectorMultiplyAccumulateTest) { + constexpr int kRow = 3; + constexpr int kCol = 4; + constexpr int kBatch = 2; + static float matrix[kRow * kCol] = {1.0, 2.0, 3.0, 4.0, // + -1.0, -2.0, -3.0, -4.0, // + 1.0, -2.0, 3.0, -4.0}; + static float vector[kCol * kBatch] = {1.0, -1.0, 1.0, -1.0, // + 2.0, -2.0, 2.0, -2.0}; + std::vector output(kRow * kBatch); + std::fill(output.begin(), output.end(), 3.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + output.data(), /*result_stride=*/1); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({1., 5., 13., // + -1., 7., 23.}))); + + std::vector output_with_stride2(kRow * kBatch * 2); + std::fill(output_with_stride2.begin(), output_with_stride2.end(), 3.0); + MatrixBatchVectorMultiplyAccumulate(matrix, kRow, kCol, vector, kBatch, + output_with_stride2.data(), + /*result_stride=*/2); + EXPECT_THAT(output_with_stride2, + ElementsAreArray(ArrayFloatNear({1., 3., 5., 3., 13., 3., // + -1., 3., 7., 3., 23., 3.}))); +} + +TEST(uKernels, VectorVectorCwiseProductTest) { + constexpr int kVectorSize = 10; + static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kVectorSize); + VectorVectorCwiseProduct(input1, input2, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45}))); +} + +TEST(uKernels, VectorVectorCwiseProductAccumulateTest) { + constexpr int kVectorSize = 10; + static float input1[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kVectorSize); + std::fill(output.begin(), output.end(), 1.0); + VectorVectorCwiseProductAccumulate(input1, input2, kVectorSize, + output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear( + {1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45}))); +} + +TEST(uKernels, VectorBatchVectorAssignTest) { + constexpr int kVectorSize = 5; + constexpr int kBatchSize = 3; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize * kBatchSize); + VectorBatchVectorAssign(input, kVectorSize, kBatchSize, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.0, -0.5, 1.0, -1.5, 2.0, 0.0, -0.5, 1.0, -1.5, 2.0, + 0.0, -0.5, 1.0, -1.5, 2.0}))); +} + +TEST(uKernels, ApplySigmoidToVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + ApplySigmoidToVector(input, kVectorSize, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.5, 0.377541, 0.731059, 0.182426, 0.880797}))); +} + +TEST(uKernels, ApplyActivationToVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + ApplyActivationToVector(input, kVectorSize, kTfLiteActRelu, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, 0.0, 1.0, 0.0, 2.0}))); + + ApplyActivationToVector(input, kVectorSize, kTfLiteActTanh, output.data()); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear( + {0.0, -0.462117, 0.761594, -0.905148, 0.964028}))); +} + +TEST(uKernels, CopyVectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + CopyVector(input, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, -0.5, 1.0, -1.5, 2.0}))); +} + +TEST(uKernels, Sub1VectorTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector output(kVectorSize); + Sub1Vector(input, kVectorSize, output.data()); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({1.0, 1.5, 0.0, 2.5, -1.0}))); +} + +TEST(uKernels, ZeroVectorTest) { + constexpr int kVectorSize = 5; + std::vector output(kVectorSize); + ZeroVector(output.data(), kVectorSize); + EXPECT_THAT(output, + ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0}))); +} + +TEST(uKernels, BatchVectorBatchVectorDotProductTest) { + constexpr int kVectorSize = 5; + constexpr int kBatch = 2; + static float input1[kVectorSize * kBatch] = {0.0, -0.5, 1.0, -1.5, 2.0, + -2.5, 3.0, -3.5, 4.0, -4.5}; + static float input2[kVectorSize * kBatch] = {0.1, -0.1, 0.1, -0.1, 0.1, + -0.1, 0.1, -0.1, 0.1, -0.1}; + std::vector output(kBatch); + BatchVectorBatchVectorDotProduct(input1, input2, kVectorSize, kBatch, + output.data(), /*result_stride=*/1); + EXPECT_THAT(output, ElementsAreArray(ArrayFloatNear({0.5, 1.75}))); +} + +TEST(uKernels, VectorShiftLeftTest) { + constexpr int kVectorSize = 5; + static float input[kVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0}; + std::vector result(kVectorSize); + VectorShiftLeft(input, kVectorSize, 3.0); + result.assign(input, input + kVectorSize); + EXPECT_THAT(result, + ElementsAreArray(ArrayFloatNear({-0.5, 1.0, -1.5, 2.0, 3.0}))); +} + +TEST(uKernels, ReductionSumVectorTest) { + constexpr int kInputVectorSize = 10; + constexpr int kOutputVectorSize1 = 5; + constexpr int kReductionSize1 = 2; + static float input[kInputVectorSize] = {0.0, -0.5, 1.0, -1.5, 2.0, + 0.0, -0.5, 1.0, 1.0, 2.0}; + std::vector result1(kOutputVectorSize1); + ReductionSumVector(input, result1.data(), kOutputVectorSize1, + kReductionSize1); + EXPECT_THAT(result1, + ElementsAreArray(ArrayFloatNear({-0.5, -0.5, 2.0, 0.5, 3.0}))); + + constexpr int kOutputVectorSize2 = 2; + constexpr int kReductionSize2 = 5; + std::vector result2(kOutputVectorSize2); + ReductionSumVector(input, result2.data(), kOutputVectorSize2, + kReductionSize2); + EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5}))); +} + +} // namespace tensor_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h new file mode 100644 index 0000000000000000000000000000000000000000..07f1cb40045fff3ae47ed4efa6ec43b0cb88a0a7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +namespace tflite { + +enum class FusedActivationFunctionType : uint8 { kNone, kRelu6, kRelu1, kRelu }; + +template +struct Dims { + int sizes[N]; + int strides[N]; +}; + +inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { + TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); + TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); + TFLITE_DCHECK(i2 >= 0 && i2 < dims.sizes[2]); + TFLITE_DCHECK(i3 >= 0 && i3 < dims.sizes[3]); + return i0 * dims.strides[0] + i1 * dims.strides[1] + i2 * dims.strides[2] + + i3 * dims.strides[3]; +} + +// Get array size, DCHECKing that the dim index is in range. +template +int ArraySize(const Dims& array, int index) { + TFLITE_DCHECK(index >= 0 && index < N); + return array.sizes[index]; +} + +// Get common array size, DCHECKing that they all agree. +template +int MatchingArraySize(const ArrayType1& array1, int index1, + const ArrayType2& array2, int index2) { + TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); + return ArraySize(array1, index1); +} + +template +int MatchingArraySize(const ArrayType1& array1, int index1, + const ArrayType2& array2, int index2, Args... args) { + TFLITE_DCHECK_EQ(ArraySize(array1, index1), ArraySize(array2, index2)); + return MatchingArraySize(array1, index1, args...); +} + +inline int RequiredBufferSizeForDims(const Dims<4>& dims) { + int max_offset = 0; + for (int i = 0; i < 4; i++) { + max_offset += (dims.sizes[i] - 1) * dims.strides[i]; + } + return max_offset + 1; +} + +template +bool IsPackedWithoutStrides(const Dims& dims) { + int expected_stride = 1; + for (int d = 0; d < N; d++) { + if (dims.strides[d] != expected_stride) return false; + expected_stride *= dims.sizes[d]; + } + return true; +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0546c00cf977af5f722a802866448b0cb293b8d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include +#include +#include "tensorflow/contrib/lite/kernels/internal/round.h" + +namespace tflite { + +TfLiteStatus GetQuantizedConvolutionMultipler( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output, double* multiplier) { + const double input_product_scale = input->params.scale * filter->params.scale; + const double bias_scale = bias->params.scale; + const double output_scale = output->params.scale; + + // TODO(ahentz): The following conditions must be guaranteed by the training + // pipeline. + TF_LITE_ENSURE(context, std::abs(input_product_scale - bias_scale) <= + 1e-6 * std::min(input_product_scale, bias_scale)); + TF_LITE_ENSURE(context, input_product_scale >= 0); + TF_LITE_ENSURE(context, input_product_scale < output_scale); + + *multiplier = input_product_scale / output_scale; + + return kTfLiteOk; +} + +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max) { + const int32_t qmin = std::numeric_limits::min(); + const int32_t qmax = std::numeric_limits::max(); + + const auto scale = output->params.scale; + const auto zero_point = output->params.zero_point; + + auto quantize = [scale, zero_point](float f) { + return zero_point + static_cast(TfLiteRound(f / scale)); + }; + + if (activation == kTfLiteActRelu) { + *act_min = std::max(qmin, quantize(0.0)); + *act_max = qmax; + } else if (activation == kTfLiteActRelu6) { + *act_min = std::max(qmin, quantize(0.0)); + *act_max = std::min(qmax, quantize(6.0)); + } else if (activation == kTfLiteActRelu1) { + *act_min = std::max(qmin, quantize(-1.0)); + *act_max = std::min(qmax, quantize(1.0)); + } else { + *act_min = qmin; + *act_max = qmax; + } +} + +void CalculateActivationRangeFloat(TfLiteFusedActivation activation, + float* activation_min, + float* activation_max) { + if (activation == kTfLiteActRelu) { + *activation_min = 0.f; + *activation_max = std::numeric_limits::max(); + } else if (activation == kTfLiteActRelu6) { + *activation_min = 0.f; + *activation_max = 6.f; + } else if (activation == kTfLiteActRelu1) { + *activation_min = -1.f; + *activation_max = 1.f; + } else { + *activation_min = std::numeric_limits::lowest(); + *activation_max = std::numeric_limits::max(); + } +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..25556ae4567aca45b3bfe4ba02b1cb58331d239d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +inline int NumDimensions(const TfLiteTensor* t) { return t->dims->size; } +inline int SizeOfDimension(const TfLiteTensor* t, int dim) { + return t->dims->data[dim]; +} +inline TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->inputs->data[index]]; +} +inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, + int index) { + return &context->tensors[node->outputs->data[index]]; +} +inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; } +inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; } + +inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, + const TfLiteNode* node, int index) { + const bool use_tensor = node->inputs->data[index] != kOptionalTensor; + if (use_tensor) { + return &context->tensors[node->inputs->data[index]]; + } + return nullptr; +} + +// Calculates the multiplication factor for a quantized convolution (or +// quantized depthwise convolution) involving the given tensors. Returns an +// error if the scales of the tensors are not compatible. +TfLiteStatus GetQuantizedConvolutionMultipler( + TfLiteContext* context, TfLiteTensor* input, TfLiteTensor* filter, + TfLiteTensor* bias, TfLiteTensor* output, double* multiplier); + +// Calculates the useful range of an activation layer given its activation +// tensor. +void CalculateActivationRangeUint8(TfLiteFusedActivation activation, + TfLiteTensor* output, int32_t* act_min, + int32_t* act_max); +void CalculateActivationRangeFloat(TfLiteFusedActivation activation, + float* activation_min, + float* activation_max); + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..f43aa372b6398a38e57dd38f3d7c7db2bd3aefc1 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace l2norm { + +// This file has two implementation of L2Norm. +enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + // TODO(ahentz): Our current implementations only support float32. + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + // TODO(ahentz): For some reason our implementations don't support + // activations. + TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = input->dims->data[1]; + output_size->data[2] = input->dims->data[2]; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_L2NORM(type) \ + type::L2Normalization( \ + GetTensorData(input), GetTensorDims(input), \ + GetTensorData(output), GetTensorDims(output)) + + if (kernel_type == kReference) { + TF_LITE_L2NORM(reference_ops); + } + if (kernel_type == kGenericOptimized) { + TF_LITE_L2NORM(optimized_ops); + } +#undef TF_LITE_L2NORM + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace l2norm + +TfLiteRegistration* Register_L2NORM_REF() { + static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare, + l2norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_L2NORM_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare, + l2norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_L2_NORMALIZATION() { + return Register_L2NORM_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..30e103f3303484c339ef98e6a68e0438291c102f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class L2NormOpModel : public SingleOpModel { + public: + L2NormOpModel(std::initializer_list input_shape, + ActivationFunctionType activation_type) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions, + CreateL2NormOptions(builder_, activation_type).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(L2NormOpTest, SimpleTest) { + L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc new file mode 100644 index 0000000000000000000000000000000000000000..c1c70d0dfa0050dee3815aa15f5d16d2e7ddc721 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace local_response_norm { + +// This file has two implementation of LocalResponseNorm. +enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = input->dims->data[1]; + output_size->data[2] = input->dims->data[2]; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_LOCAL_RESPONSE_NORM(type) \ + type::LocalResponseNormalization( \ + GetTensorData(input), GetTensorDims(input), params->radius, \ + params->bias, params->alpha, params->beta, GetTensorData(output), \ + GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_LOCAL_RESPONSE_NORM(reference_ops); + } + if (kernel_type == kGenericOptimized) { + TF_LITE_LOCAL_RESPONSE_NORM(optimized_ops); + } +#undef TF_LITE_LOCAL_RESPONSE_NORM + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace local_response_norm + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, local_response_norm::Prepare, + local_response_norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORM_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, local_response_norm::Prepare, + local_response_norm::Eval}; + return &r; +} + +TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION() { + return Register_LOCAL_RESPONSE_NORM_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d75ce258a04c820d8f82735988c01d0154ef36f2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LocalResponseNormOpModel : public SingleOpModel { + public: + LocalResponseNormOpModel(std::initializer_list input_shape, int radius, + float bias, float alpha, float beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOptions_LocalResponseNormalizationOptions, + CreateLocalResponseNormalizationOptions(builder_, radius, bias, + alpha, beta) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(LocalResponseNormOpTest, SameAsL2Norm) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/1.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 2. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}))); +} + +TEST(LocalResponseNormOpTest, WithAlpha) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/0.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 3. + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear( + {-0.275, 0.15, 0.175, 0.3, -0.175, 0.025}))); +} + +TEST(LocalResponseNormOpTest, WithBias) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/20, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + // The result is every input divided by 5. + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.22, 0.12, 0.14, 0.24, -0.14, 0.02}))); +} + +TEST(LocalResponseNormOpTest, SmallRadius) { + LocalResponseNormOpModel m({1, 1, 1, 6}, /*radius=*/2, /*bias=*/9.0, + /*alpha=*/4.0, /*beta=*/0.5); + m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-0.264926, 0.125109, 0.140112, 0.267261, -0.161788, 0.0244266}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f73b56ed9790b216adc788490faebaabd2bc756 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc @@ -0,0 +1,204 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// LSH Projection projects an input to a bit vector via locality senstive +// hashing. +// +// Options: +// Sparse: +// Computed bit vector is considered to be sparse. +// Each output element is an int32 made up by multiple bits computed from +// hash functions. +// +// Dense: +// Computed bit vector is considered to be dense. Each output element is +// either 0 or 1 that represents a bit. +// +// Input: +// Tensor[0]: Hash functions. Dim.size == 2, DataType: Float. +// Tensor[0].Dim[0]: Num of hash functions. +// Tensor[0].Dim[1]: Num of projected output bits generated by +// each hash function. +// In sparse case, Tensor[0].Dim[1] + ceil( log2(Tensor[0].Dim[0] )) <= 32. +// +// Tensor[1]: Input. Dim.size >= 1, No restriction on DataType. +// Tensor[2]: Optional, Weight. Dim.size == 1, DataType: Float. +// If not set, each element of input is considered to have same +// weight of 1.0 Tensor[1].Dim[0] == Tensor[2].Dim[0] +// +// Output: +// Sparse: +// Output.Dim == { Tensor[0].Dim[0] } +// A tensor of int32 that represents hash signatures, +// +// NOTE: To avoid collisions across hash functions, an offset value of +// k * (1 << Tensor[0].Dim[1]) will be added to each signature, +// k is the index of the hash function. +// Dense: +// Output.Dim == { Tensor[0].Dim[0] * Tensor[0].Dim[1] } +// A flattened tensor represents projected bit vectors. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include + +namespace tflite { +namespace ops { +namespace builtin { +namespace lsh_projection { + +TfLiteStatus Resize(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + TF_LITE_ENSURE(context, NumInputs(node) == 2 || NumInputs(node) == 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* hash = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(hash), 2); + // Support up to 32 bits. + TF_LITE_ENSURE(context, SizeOfDimension(hash, 1) <= 32); + + TfLiteTensor* input = GetInput(context, node, 1); + TF_LITE_ENSURE(context, NumDimensions(input) >= 1); + + if (NumInputs(node) == 3) { + TfLiteTensor* weight = GetInput(context, node, 2); + TF_LITE_ENSURE_EQ(context, NumDimensions(weight), 1); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(weight, 0), + SizeOfDimension(input, 0)); + } + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(1); + switch (params->type) { + case kTfLiteLshProjectionSparse: + outputSize->data[0] = SizeOfDimension(hash, 0); + break; + case kTfLiteLshProjectionDense: + outputSize->data[0] = SizeOfDimension(hash, 0) * SizeOfDimension(hash, 1); + break; + default: + return kTfLiteError; + } + return context->ResizeTensor(context, output, outputSize); +} + +// Compute sign bit of dot product of hash(seed, input) and weight. +// NOTE: use float as seed, and convert it to double as a temporary solution +// to match the trained model. This is going to be changed once the new +// model is trained in an optimized method. +// +int RunningSignBit(const TfLiteTensor* input, const TfLiteTensor* weight, + float seed) { + double score = 0.0; + int input_item_bytes = input->bytes / SizeOfDimension(input, 0); + char* input_ptr = input->data.raw; + + const size_t seed_size = sizeof(float); + const size_t key_bytes = sizeof(float) + input_item_bytes; + std::unique_ptr key(new char[key_bytes]); + + for (int i = 0; i < SizeOfDimension(input, 0); ++i) { + // Create running hash id and value for current dimension. + memcpy(key.get(), &seed, seed_size); + memcpy(key.get() + seed_size, input_ptr, input_item_bytes); + + int64_t hash_signature = ::util::Fingerprint64(key.get(), key_bytes); + double running_value = static_cast(hash_signature); + input_ptr += input_item_bytes; + if (weight == nullptr) { + score += running_value; + } else { + score += weight->data.f[i] * running_value; + } + } + + return (score > 0) ? 1 : 0; +} + +void SparseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, + const TfLiteTensor* weight, int32_t* out_buf) { + int num_hash = SizeOfDimension(hash, 0); + int num_bits = SizeOfDimension(hash, 1); + for (int i = 0; i < num_hash; i++) { + int32_t hash_signature = 0; + for (int j = 0; j < num_bits; j++) { + float seed = hash->data.f[i * num_bits + j]; + int bit = RunningSignBit(input, weight, seed); + hash_signature = (hash_signature << 1) | bit; + } + *out_buf++ = hash_signature + i * (1 << num_bits); + } +} + +void DenseLshProjection(const TfLiteTensor* hash, const TfLiteTensor* input, + const TfLiteTensor* weight, int32_t* out_buf) { + int num_hash = SizeOfDimension(hash, 0); + int num_bits = SizeOfDimension(hash, 1); + for (int i = 0; i < num_hash; i++) { + for (int j = 0; j < num_bits; j++) { + float seed = hash->data.f[i * num_bits + j]; + int bit = RunningSignBit(input, weight, seed); + *out_buf++ = bit; + } + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + int32_t* out_buf = GetOutput(context, node, 0)->data.i32; + TfLiteTensor* hash = GetInput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 1); + TfLiteTensor* weight = + NumInputs(node) == 2 ? nullptr : GetInput(context, node, 2); + + switch (params->type) { + case kTfLiteLshProjectionDense: + DenseLshProjection(hash, input, weight, out_buf); + break; + case kTfLiteLshProjectionSparse: + SparseLshProjection(hash, input, weight, out_buf); + break; + default: + return kTfLiteError; + } + + return kTfLiteOk; +} +} // namespace lsh_projection + +TfLiteRegistration* Register_LSH_PROJECTION() { + static TfLiteRegistration r = {nullptr, nullptr, lsh_projection::Resize, + lsh_projection::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..414d728dfc153058ec878d3c766f58e86815cd3f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +class LSHProjectionOpModel : public SingleOpModel { + public: + LSHProjectionOpModel(LSHProjectionType type, + std::initializer_list hash_shape, + std::initializer_list input_shape, + std::initializer_list weight_shape) { + hash_ = AddInput(TensorType_FLOAT32); + input_ = AddInput(TensorType_INT32); + if (weight_shape.size() > 0) { + weight_ = AddInput(TensorType_FLOAT32); + } + output_ = AddOutput(TensorType_INT32); + + SetBuiltinOp(BuiltinOperator_LSH_PROJECTION, + BuiltinOptions_LSHProjectionOptions, + CreateLSHProjectionOptions(builder_, type).Union()); + if (weight_shape.size() > 0) { + BuildInterpreter({hash_shape, input_shape, weight_shape}); + } else { + BuildInterpreter({hash_shape, input_shape}); + } + + output_size_ = 1; + for (int i : hash_shape) { + output_size_ *= i; + if (type == LSHProjectionType_SPARSE) { + break; + } + } + } + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetHash(std::initializer_list data) { + PopulateTensor(hash_, data); + } + + void SetWeight(std::initializer_list f) { PopulateTensor(weight_, f); } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int hash_; + int weight_; + int output_; + + int output_size_; +}; + +TEST(LSHProjectionOpTest2, Dense1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_DENSE, {3, 2}, {5}, {5}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({1.0, 1.0, 1.0, 1.0, 1.0}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0, 0, 0, 1, 0, 0)); +} + +TEST(LSHProjectionOpTest2, Sparse1DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5}, {}); + + m.SetInput({12345, 54321, 67890, 9876, -12345678}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 0, 4 + 1, 8 + 0)); +} + +TEST(LSHProjectionOpTest2, Sparse3DInputs) { + LSHProjectionOpModel m(LSHProjectionType_SPARSE, {3, 2}, {5, 2, 2}, {5}); + + m.SetInput({1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912, + 9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543}); + m.SetHash({0.123, 0.456, -0.321, 1.234, 5.678, -4.321}); + m.SetWeight({0.12, 0.34, 0.56, 0.67, 0.78}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(0 + 2, 4 + 1, 8 + 1)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c06264d845c24e71647b6fd2374734be32383ef --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm.cc @@ -0,0 +1,515 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace lstm { + +// Input Tensors of size {n_batch, n_input} +constexpr int kInputTensor = 0; + +// Input weight tensors of size: {n_cell, n_input} +constexpr int kInputToInputWeightsTensor = 1; // Optional +constexpr int kInputToForgetWeightsTensor = 2; +constexpr int kInputToCellWeightsTensor = 3; +constexpr int kInputToOutputWeightsTensor = 4; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kRecurrentToInputWeightsTensor = 5; // Optional +constexpr int kRecurrentToForgetWeightsTensor = 6; +constexpr int kRecurrentToCellWeightsTensor = 7; +constexpr int kRecurrentToOutputWeightsTensor = 8; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kCellToInputWeightsTensor = 9; // Optional +constexpr int kCellToForgetWeightsTensor = 10; // Optional +constexpr int kCellToOutputWeightsTensor = 11; // Optional + +// Gates bias tensors of size {n_cell} +constexpr int kInputGateBiasTensor = 12; // Optional +constexpr int kForgetGateBiasTensor = 13; +constexpr int kCellGateBiasTensor = 14; +constexpr int kOutputGateBiasTensor = 15; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kProjectionWeightsTensor = 16; // Optional +// Projection bias tensor of size {n_output} +constexpr int kProjectionBiasTensor = 17; // Optional + +// Output tensors. +constexpr int kScratchBufferTensor = 0; +constexpr int kOutputStateTensor = 1; +constexpr int kCellStateTensor = 2; +constexpr int kOutputTensor = 3; + +// Check that input tensor dimensions matches with each other. +TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, + TfLiteNode* node, int n_input, + int n_output, int n_cell) { + auto* params = reinterpret_cast(node->builtin_data); + + // Making sure clipping parameters have valid values. + // == 0 means no clipping + // > 0 means clipping + TF_LITE_ENSURE(context, params->cell_clip >= 0); + TF_LITE_ENSURE(context, params->proj_clip >= 0); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + if (input_to_input_weights) { + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); + } + + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); + + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + if (recurrent_to_input_weights) { + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], + n_output); + } + + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], + n_output); + + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], + n_output); + + // We make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). + const bool cifg_weights_all_or_none = + ((input_to_input_weights != nullptr) && + (recurrent_to_input_weights != nullptr)) || + ((input_to_input_weights == nullptr) && + (recurrent_to_input_weights == nullptr)); + TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + if (cell_to_input_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + if (cell_to_forget_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + if (cell_to_output_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + } + + // Making sure the peephole weights are there all or none. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool peephole_weights_all_or_none = + ((cell_to_input_weights != nullptr || use_cifg) && + (cell_to_forget_weights != nullptr) && + (cell_to_output_weights != nullptr)) || + ((cell_to_input_weights == nullptr) && + (cell_to_forget_weights == nullptr) && + (cell_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + } else { + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + } + + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); + + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + if (projection_weights) { + TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); + } + + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + if (projection_bias) { + TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. + // TODO(ghodrat): make sure this is correct. + const bool projecton_tensors_consistent = + ((projection_weights != nullptr) || (projection_bias == nullptr)); + TF_LITE_ENSURE(context, projecton_tensors_consistent == true); + + return kTfLiteOk; +} + +// Resize the output, state and scratch tensors based on the sizes of the input +// tensors. Also check that the size of the input tensors match each other. +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 18); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); + + // Inferring batch size, number of outputs and number of cells from the + // input tensors. + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE(context, input->dims->size > 1); + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + const int n_cell = input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], + n_cell); + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Check that input tensor dimensions matches with each other. + CheckInputTensorDimensions(context, node, n_input, n_output, n_cell); + + // Get the pointer to output, state and scratch buffer tensors. + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + // TODO(ghodrat): Modify this as soon as we have a finalized method for + // scratch buffers. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + + // Resize the output and output_state tensors. + TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); + output_size->data[0] = n_batch; + output_size->data[1] = n_output; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size)); + + TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2); + output_state_size->data[0] = n_batch; + output_state_size->data[1] = n_output; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, output_state, output_state_size)); + + // Resize the output, state and scratch buffer tensors. + TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2); + cell_size->data[0] = n_batch; + cell_size->data[1] = n_cell; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, cell_state, cell_size)); + + // Mark state tensors as persistent tensors. + output_state->allocation_type = kTfLiteArenaRwPersistent; + cell_state->allocation_type = kTfLiteArenaRwPersistent; + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + const bool use_cifg = (input_to_input_weights == nullptr); + if (use_cifg) { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + // Reserving space for Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 3; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + } else { + TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); + scratch_buffer_size->data[0] = n_batch; + // Reserving space for Input, Cell, Forget, Output gates + scratch_buffer_size->data[1] = n_cell * 4; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, + scratch_buffer_size)); + } + return kTfLiteOk; +} + +// The LSTM Op engine. +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, kInputToForgetWeightsTensor); + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, kInputToCellWeightsTensor); + TfLiteTensor* input_to_output_weights = + GetInput(context, node, kInputToOutputWeightsTensor); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, kRecurrentToForgetWeightsTensor); + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, kRecurrentToCellWeightsTensor); + TfLiteTensor* recurrent_to_output_weights = + GetInput(context, node, kRecurrentToOutputWeightsTensor); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); + + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, kInputGateBiasTensor); + TfLiteTensor* forget_gate_bias = + GetInput(context, node, kForgetGateBiasTensor); + TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor); + TfLiteTensor* output_gate_bias = + GetInput(context, node, kOutputGateBiasTensor); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, kProjectionWeightsTensor); + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, kProjectionBiasTensor); + + TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor); + TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + const int n_batch = input->dims->data[0]; + const int n_input = input->dims->data[1]; + // n_cell and n_output will be the same size when there is no projection. + const int n_cell = input_to_output_weights->dims->data[0]; + const int n_output = recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool use_peephole = (cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor); + float* input_gate_scratch = nullptr; + float* cell_scratch = nullptr; + float* forget_gate_scratch = nullptr; + float* output_gate_scratch = nullptr; + if (use_cifg) { + cell_scratch = scratch_buffer->data.f; + forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + } else { + input_gate_scratch = scratch_buffer->data.f; + cell_scratch = scratch_buffer->data.f + n_cell * n_batch; + forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch; + output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch; + } + + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell, + n_batch, input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell, + n_batch, forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell, + n_batch, output_gate_scratch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights->data.f, n_cell, n_input, input->data.f, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights->data.f, n_cell, n_input, input->data.f, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights->data.f, n_cell, n_input, input->data.f, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights->data.f, n_cell, n_input, input->data.f, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights->data.f, n_cell, n_output, output_state->data.f, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights->data.f, n_cell, n_output, + output_state->data.f, n_batch, output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, + cell_state->data.f, n_batch * n_cell, + cell_state->data.f); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, + cell_state->data.f); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state->data.f); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell, + params->cell_clip, cell_state->data.f); + } + + // For each batch and cell: update the output gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights != nullptr); + const bool use_projection_bias = (projection_bias != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output, + n_batch, output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights->data.f, n_output, n_cell, output_gate_scratch, + n_batch, output->data.f, /*result_stride=*/1); + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output->data.f, n_batch * n_output, + params->proj_clip, output->data.f); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output->data.f); + } + tensor_utils::CopyVector(output->data.f, n_batch * n_output, + output_state->data.f); + + return kTfLiteOk; +} + +} // namespace lstm + +TfLiteRegistration* Register_LSTM() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + lstm::Prepare, lstm::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c068286b0d84bcb51ebb0e239350a42863de6523 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -0,0 +1,1087 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class LSTMOpModel : public SingleOpModel { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. + output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + private: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + int output_state_; + int cell_state_; + int scratch_buffer_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + +TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/false, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, + -0.15358765, -0.03716109, 0.12507336, + 0.41193449, -0.20860538, -0.15053082, + 0.09120187, 0.24278517, -0.12222792}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, + -0.05163646, -0.42312205, -0.01218222, + 0.24201041, -0.08124574, -0.358325, + -0.04621704, 0.21641694, -0.06471302}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output + i * lstm.num_outputs(); + float* golden_end = golden_start + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/false, /*use_peephole=*/true, + /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights( + {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); + + lstm.SetInputToForgetWeights( + {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, + -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, + -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, + 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, + -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, + 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, + 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, + 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, + -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, + -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, + -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, + 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, + 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, + 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); + + lstm.SetInputToCellWeights( + {-0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); + + lstm.SetInputToOutputWeights( + {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); + + lstm.SetInputGateBias( + {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, + -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, + -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, + 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); + + lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}); + + lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}); + + lstm.SetOutputGateBias( + {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, + 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, + 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, + -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); + + lstm.SetRecurrentToInputWeights( + {-0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}); + + lstm.SetRecurrentToForgetWeights( + {-0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}); + + lstm.SetRecurrentToCellWeights( + {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); + + lstm.SetRecurrentToOutputWeights({ + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }); + + lstm.SetCellToInputWeights( + {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); + + lstm.SetCellToForgetWeights( + {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); + + lstm.SetCellToOutputWeights( + {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); + + lstm.SetProjectionWeights( + {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, + 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, + -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, + -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, + 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, + 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, + -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, + -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, + 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, + 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, + 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, + 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, + 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, + -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, + 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, + -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, + 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, + -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, + -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, + -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, + 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, + -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, + 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, + -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, + 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, + 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, + 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, + 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, + -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, + 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, + -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, + 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, + 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, + -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, + -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, + 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, + -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, + 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, + -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, + 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, + -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, + -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, + 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, + 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, + 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); + + static float lstm_input[][20] = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, + 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, + 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, + 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, + 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; + + static float lstm_golden_output[][64] = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + const int input_sequence_size = + sizeof(lstm_input[0]) / sizeof(float) / (lstm.num_inputs()); + for (int i = 0; i < input_sequence_size; i++) { + float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(0, batch0_start, batch0_end); + + float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); + float* batch1_end = batch1_start + lstm.num_inputs(); + lstm.SetInput(lstm.num_inputs(), batch1_start, batch1_end); + + lstm.Invoke(); + + float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); + float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); + float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); + std::vector expected; + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc new file mode 100644 index 0000000000000000000000000000000000000000..81c73f2523186c2d4072d56bdc8980fcdbb588a3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -0,0 +1,167 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace mul { + +// This file has three implementation of Mul. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_MUL(type) \ + type::Mul(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops); + } else { + TF_LITE_MUL(optimized_ops); + } +#undef TF_LITE_MUL +} + +template +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteMulParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + auto input1_offset = -input1->params.zero_point; + auto input2_offset = -input2->params.zero_point; + auto output_offset = output->params.zero_point; + + int32_t output_multiplier; + int output_shift; + + double real_multiplier = + input1->params.scale * input2->params.scale / output->params.scale; + QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, + &output_shift); + + int32 output_activation_min, output_activation_max; + CalculateActivationRangeUint8(params->activation, output, + &output_activation_min, &output_activation_max); + +#define TF_LITE_MUL(type) \ + type::BroadcastMul(GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, output_offset, \ + output_multiplier, output_shift, output_activation_min, \ + output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + if (kernel_type == kReference) { + TF_LITE_MUL(reference_ops); + } else { + TF_LITE_MUL(optimized_ops); + } +#undef TF_LITE_MUL +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalFloat(context, node, params, input1, input2, output); + } else if (output->type == kTfLiteUInt8) { + EvalQuantized(context, node, params, input1, input2, output); + } else { + context->ReportError(context, + "Mul only supports FLOAT32 and quantized UINT8 now."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace mul + +TfLiteRegistration* Register_MUL_REF() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + mul::Eval}; + return &r; +} + +TfLiteRegistration* Register_MUL() { +#ifdef USE_NEON + return Register_MUL_NEON_OPT(); +#else + return Register_MUL_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4255cfe18a043c55f3ce7292afdedb6e988a28a2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -0,0 +1,126 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseMulOpModel : public SingleOpModel { + public: + BaseMulOpModel(TensorData input, TensorData output, + ActivationFunctionType activation_type) { + input1_ = AddInput(input); + input2_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + protected: + int input1_; + int input2_; + int output_; +}; + +class FloatMulOpModel : public BaseMulOpModel { + public: + using BaseMulOpModel::BaseMulOpModel; + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +// For quantized Mul, the error shouldn't exceed (2*step + step^2). +// The param min=-1.0 & max=1.0 is used in the following tests. +// The tolerance value is ~0.0157. +const float kQuantizedStep = 2.0 / 255.0; +const float kQuantizedTolerance = + 2.0 * kQuantizedStep + kQuantizedStep * kQuantizedStep; + +class QuantizedMulOpModel : public BaseMulOpModel { + public: + using BaseMulOpModel::BaseMulOpModel; + + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatMulOpTest, NoActivation) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); +} + +TEST(FloatMulOpTest, ActivationRELU1) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 1.0}))); +} + +TEST(FloatMulOpTest, VariousInputShapes) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4, 1.21, 0.2}))) + << "With shape number " << i; + } +} + +TEST(QuantizedMulOpTest, NoActivation) { + QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {}, -1.0, 1.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); + m.QuantizeAndPopulate(m.input2(), {0.6, 0.4, 0.9, 0.8}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-0.48, 0.08, 0.81, 0.56}, + kQuantizedTolerance))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..7535afaf8ea52d855e2e4773e56ce2118a16447c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/op_macros.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ + +#define TF_LITE_FATAL(msg) \ + do { \ + fprintf(stderr, "%s\n", (msg)); \ + exit(1); \ + } while (0) +#define TF_LITE_ASSERT(x) \ + do { \ + if (!(x)) TF_LITE_FATAL(#x); \ + } while (0) +#define TF_LITE_ASSERT_EQ(x, y) \ + do { \ + if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \ + } while (0) + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..17166715ca30ff3d8ba3d384110e403f8910e39d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -0,0 +1,340 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +class LSTMOpModel : public SingleOpModel { + public: + LSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. + output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + void Verify() { + auto model = tflite::UnPackModel(builder_.GetBufferPointer()); + EXPECT_NE(model, nullptr); + } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + + private: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + int output_state_; + int cell_state_; + int scratch_buffer_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; +}; + + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + + LSTMOpModel lstm(n_batch, n_input, n_cell, n_output, + /*use_cifg=*/true, /*use_peephole=*/true, + /*use_projection_weights=*/false, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + // Verify the model by unpacking it. + lstm.Verify(); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h new file mode 100644 index 0000000000000000000000000000000000000000..3a60274524c468ef29e522de5569e0d8354974c2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/padding.h @@ -0,0 +1,28 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ + +namespace tflite { + +inline int ComputePadding(int stride, int in_size, int filter_size, + int out_size) { + int padding = ((out_size - 1) * stride + filter_size - in_size) / 2; + return padding > 0 ? padding : 0; +} + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc new file mode 100644 index 0000000000000000000000000000000000000000..b79880110897a1438a589d97363fd861c61667e7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pooling.cc @@ -0,0 +1,355 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pooling { + +// This file has two implementation of each pooling op. +enum KernelType { + kReference, + kGenericOptimized, +}; + +enum PoolType { + kAverage, + kMax, + kL2, +}; + +struct OpData { + TfLitePaddingValues padding; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // This is a builtin op, so we don't use the contents in 'buffer', if any. + // Instead, we allocate a new object to carry information from Prepare() to + // Eval(). + return new OpData; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +template +TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + int batches = input->dims->data[0]; + int height = input->dims->data[1]; + int width = input->dims->data[2]; + int channels_out = input->dims->data[3]; + + // Matching GetWindowedOutputSize in TensorFlow. + auto padding = params->padding; + auto computeOutSize = [padding](int imageSize, int filterSize, + int stride) -> int { + return padding == kTfLitePaddingSame + ? (imageSize + stride - 1) / stride + : padding == kTfLitePaddingValid + ? (imageSize - filterSize + stride) / stride + : 0; + }; + + int outWidth = + computeOutSize(width, params->filter_width, params->stride_width); + int outHeight = + computeOutSize(height, params->filter_height, params->stride_height); + + data->padding.height = ComputePadding(params->stride_height, height, + params->filter_height, outHeight); + data->padding.width = ComputePadding(params->stride_width, width, + params->filter_width, outWidth); + + if (input->type == kTfLiteUInt8) { + if (pool_type == kAverage || pool_type == kMax) { + TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale); + TF_LITE_ENSURE_EQ(context, input->params.zero_point, + output->params.zero_point); + } + if (pool_type == kL2) { + // We currently don't have a quantized implementation of L2Pool + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + } + } + + TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4); + outputSize->data[0] = batches; + outputSize->data[1] = outHeight; + outputSize->data[2] = outWidth; + outputSize->data[3] = channels_out; + return context->ResizeTensor(context, output, outputSize); +} + +template +void AverageEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_AVERAGE_POOL(reference_ops); + } else { + TF_LITE_AVERAGE_POOL(optimized_ops); + } +#undef TF_LITE_AVERAGE_POOL +} + +template +void AverageEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeUint8(params->activation, output, &activation_min, + &activation_max); +#define TF_LITE_AVERAGE_POOL(type) \ + type::AveragePool(GetTensorData(input), GetTensorDims(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, \ + activation_min, activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_AVERAGE_POOL(reference_ops); + } else { + TF_LITE_AVERAGE_POOL(optimized_ops); + } +#undef TF_LITE_AVERAGE_POOL +} + +template +void MaxEvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MAX_POOL(reference_ops); + } else { + TF_LITE_MAX_POOL(optimized_ops); + } +#undef TF_LITE_MAX_POOL +} + +template +void MaxEvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, + TfLiteTensor* input, TfLiteTensor* output) { + int32_t activation_min; + int32_t activation_max; + CalculateActivationRangeUint8(params->activation, output, &activation_min, + &activation_max); +#define TF_LITE_MAX_POOL(type) \ + type::MaxPool(GetTensorData(input), GetTensorDims(input), \ + params->stride_width, params->stride_height, \ + data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), \ + GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_MAX_POOL(reference_ops); + } else { + TF_LITE_MAX_POOL(optimized_ops); + } +#undef TF_LITE_MAX_POOL +} + +template +void L2EvalFloat(TfLiteContext* context, TfLiteNode* node, + TfLitePoolParams* params, OpData* data, TfLiteTensor* input, + TfLiteTensor* output) { + float activation_min, activation_max; + CalculateActivationRangeFloat(params->activation, &activation_min, + &activation_max); +#define TF_LITE_L2_POOL(type) \ + type::L2Pool( \ + GetTensorData(input), GetTensorDims(input), params->stride_width, \ + params->stride_height, data->padding.width, data->padding.height, \ + params->filter_width, params->filter_height, activation_min, \ + activation_max, GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_L2_POOL(reference_ops); + } else { + TF_LITE_L2_POOL(optimized_ops); + } +#undef TF_LITE_L2_POOL +} + +#undef TF_LITE_KERNEL_TYPE_DISPATCH + +template +TfLiteStatus AverageEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + AverageEvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + AverageEvalQuantized(context, node, params, data, input, + output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +template +TfLiteStatus MaxEval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + MaxEvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + MaxEvalQuantized(context, node, params, data, input, output); + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +template +TfLiteStatus L2Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TfLiteTensor* output = GetOutput(context, node, 0); + TfLiteTensor* input = GetInput(context, node, 0); + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + L2EvalFloat(context, node, params, data, input, output); + break; + case kTfLiteUInt8: + // We don't have a quantized implementation, so just fall through to the + // 'default' case. + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } + return kTfLiteOk; +} + +} // namespace pooling + +TfLiteRegistration* Register_AVERAGE_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::AverageEval}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::MaxEval}; + return &r; +} + +TfLiteRegistration* Register_L2_POOL_REF() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::L2Eval}; + return &r; +} + +TfLiteRegistration* Register_AVERAGE_POOL_GENERIC_OPT() { + static TfLiteRegistration r = { + pooling::Init, pooling::Free, pooling::GenericPrepare, + pooling::AverageEval}; + return &r; +} + +TfLiteRegistration* Register_MAX_POOL_GENERIC_OPT() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::MaxEval}; + return &r; +} + +TfLiteRegistration* Register_L2_POOL_GENERIC_OPT() { + static TfLiteRegistration r = {pooling::Init, pooling::Free, + pooling::GenericPrepare, + pooling::L2Eval}; + return &r; +} + +TfLiteRegistration* Register_AVERAGE_POOL_2D() { + return Register_AVERAGE_POOL_GENERIC_OPT(); +} + +TfLiteRegistration* Register_MAX_POOL_2D() { + return Register_MAX_POOL_GENERIC_OPT(); +} + +TfLiteRegistration* Register_L2_POOL_2D() { + return Register_L2_POOL_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..01c91b2ba905e249c36af19f175c68a7e7f17f6d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pooling_test.cc @@ -0,0 +1,161 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BasePoolingOpModel : public SingleOpModel { + public: + // TODO(ahentz): Also test different activation types, bias, padding types, + // stride values. + BasePoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + protected: + int input_; + int output_; +}; + +class FloatPoolingOpModel : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedPoolingOpModel : public BasePoolingOpModel { + public: + using BasePoolingOpModel::BasePoolingOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +TEST(FloatPoolingOpTest, AveragePool) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(QuantizedPoolingOpTest, AveragePool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. + QuantizedPoolingOpModel m( + BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({2.75, 5.75}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({44, 92})); +} + +TEST(FloatPoolingOpTest, MaxPool) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(QuantizedPoolingOpTest, MaxPool) { + // Choose the input ranges carefully so that the dequantized output matches + // the results of the float model above. + QuantizedPoolingOpModel m( + BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_UINT8, {1, 2, 4, 1}, 0, 15.9375}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_UINT8, {}, 0, 15.9375}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({6, 10}))); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({96, 160})); +} + +TEST(FloatPoolingOpTest, L2Pool) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca7a0dd1949a3a31d26be770a7df781cc5fe7533 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/kernels/register.h" + +namespace tflite { +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_RELU(); +TfLiteRegistration* Register_RELU1(); +TfLiteRegistration* Register_RELU6(); +TfLiteRegistration* Register_TANH(); +TfLiteRegistration* Register_LOGISTIC(); +TfLiteRegistration* Register_AVERAGE_POOL_2D(); +TfLiteRegistration* Register_MAX_POOL_2D(); +TfLiteRegistration* Register_L2_POOL_2D(); +TfLiteRegistration* Register_CONV_2D(); +TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); +TfLiteRegistration* Register_SVDF(); +TfLiteRegistration* Register_RNN(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP(); +TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); +TfLiteRegistration* Register_FULLY_CONNECTED(); +TfLiteRegistration* Register_LSH_PROJECTION(); +TfLiteRegistration* Register_HASHTABLE_LOOKUP(); +TfLiteRegistration* Register_SOFTMAX(); +TfLiteRegistration* Register_CONCATENATION(); +TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_MUL(); +TfLiteRegistration* Register_L2_NORMALIZATION(); +TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); +TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_RESHAPE(); +TfLiteRegistration* Register_RESIZE_BILINEAR(); +TfLiteRegistration* Register_SKIP_GRAM(); +TfLiteRegistration* Register_SPACE_TO_DEPTH(); + +BuiltinOpResolver::BuiltinOpResolver() { + AddBuiltin(BuiltinOperator_RELU, Register_RELU()); + AddBuiltin(BuiltinOperator_RELU1, Register_RELU1()); + AddBuiltin(BuiltinOperator_RELU6, Register_RELU6()); + AddBuiltin(BuiltinOperator_TANH, Register_TANH()); + AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC()); + AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D()); + AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D()); + AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D()); + AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D()); + AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()); + AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); + AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); + AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + Register_EMBEDDING_LOOKUP_SPARSE()); + AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Register_FULLY_CONNECTED()); + AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION()); + AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP()); + AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); + AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); + AddBuiltin(BuiltinOperator_ADD, Register_ADD()); + AddBuiltin(BuiltinOperator_MUL, Register_MUL()); + AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); + AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + Register_LOCAL_RESPONSE_NORMALIZATION()); + AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); + AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); + AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); + AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); +} + +TfLiteRegistration* BuiltinOpResolver::FindOp( + tflite::BuiltinOperator op) const { + auto it = builtins_.find(op); + return it != builtins_.end() ? it->second : nullptr; +} + +TfLiteRegistration* BuiltinOpResolver::FindOp(const char* op) const { + auto it = custom_ops_.find(op); + return it != custom_ops_.end() ? it->second : nullptr; +} + +void BuiltinOpResolver::AddBuiltin(tflite::BuiltinOperator op, + TfLiteRegistration* registration) { + registration->builtin_code = op; + builtins_.insert(std::make_pair(op, registration)); +} + +void BuiltinOpResolver::AddCustom(const char* name, + TfLiteRegistration* registration) { + registration->builtin_code = BuiltinOperator_CUSTOM; + custom_ops_.insert(std::make_pair(std::string(name), registration)); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h new file mode 100644 index 0000000000000000000000000000000000000000..28f5e0fcc80a14cf9fb6fb19b795d0c0d55e0df9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/register.h @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ + +#include +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace ops { +namespace builtin { + +class BuiltinOpResolver : public OpResolver { + public: + BuiltinOpResolver(); + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override; + TfLiteRegistration* FindOp(const char* op) const override; + void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration); + void AddCustom(const char* name, TfLiteRegistration* registration); + + private: + struct BuiltinOperatorHasher { + size_t operator()(const tflite::BuiltinOperator& x) const { + return std::hash()(static_cast(x)); + } + }; + std::unordered_map + builtins_; + std::unordered_map custom_ops_; +}; + +} // namespace builtin +} // namespace ops +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc new file mode 100644 index 0000000000000000000000000000000000000000..f3e6ddc9f480e3863cac52157ae28b7329ee2088 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/reshape.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace reshape { + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // TODO(ahentz): we are often given a tensor with the shape but we only pay + // attention to what the shape specified in 'params'. + TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // Tensorflow's Reshape allows one of the shape components to have the + // special -1 value, meaning it will be calculated automatically based on the + // input. Here we calculate what that dimension should be so that the number + // of output elements in the same as the number of input elements. + int num_input_elements = 1; + for (int i = 0; i < NumDimensions(input); ++i) { + num_input_elements *= SizeOfDimension(input, i); + } + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(params->num_dimensions); + int num_output_elements = 1; + int strech_dim = -1; + for (int i = 0; i < params->num_dimensions; ++i) { + int value = params->shape[i]; + if (value == -1) { + TF_LITE_ENSURE_EQ(context, strech_dim, -1); + strech_dim = i; + } else { + num_output_elements *= value; + output_size->data[i] = value; + } + } + if (strech_dim != -1) { + output_size->data[strech_dim] = num_input_elements / num_output_elements; + num_output_elements *= output_size->data[strech_dim]; + } + + TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); + return context->ResizeTensor(context, output, output_size); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + memcpy(output->data.raw, input->data.raw, input->bytes); + + return kTfLiteOk; +} + +} // namespace reshape + +TfLiteRegistration* Register_RESHAPE() { + static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare, + reshape::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0fbcf6e6aa311d2cac491336ee54ccf58bbda8fd --- /dev/null +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(ReshapeOpTest, MismatchedDimensions) { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {2, 1}), + "num_input_elements != num_output_elements"); +} + +TEST(ReshapeOpTest, TooManyDimensions) { + EXPECT_DEATH( + ReshapeOpModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {1, 2, 3, 4, 5, 6, 7, 8, 9}), + "Found too many dimensions"); +} + +TEST(ReshapeOpTest, TooManySpecialDimensions) { + EXPECT_DEATH(ReshapeOpModel({1, 2, 4, 1}, {-1, -1, 2, 4}), + "strech_dim != -1"); +} + +TEST(ReshapeOpTest, SimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + +TEST(ReshapeOpTest, WithStretchDimension) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 1, -1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1, 4})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc new file mode 100644 index 0000000000000000000000000000000000000000..1613c9a89faa3579b913408cc09cdad7f942cb99 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace resize_bilinear { + +// This file has three implementation of RESIZE_BILINEAR. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // TODO(ahentz): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + // TODO(ahentz): Our current implementations only support float32. + TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = params->new_height; + output_size->data[2] = params->new_width; + output_size->data[3] = input->dims->data[3]; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + // We have to fake a tensor here, to satisfy ResizeBilinear(). + int32 output_size_data[2] = {params->new_height, params->new_width}; + + if (output->type == kTfLiteFloat32) { +#define TF_LITE_RESIZE_BILINEAR(type) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + output_size_data, GetTensorDims({1, 1, 1, 2}), \ + GetTensorData(output), GetTensorDims(output)) + + if (kernel_type == kReference) { + TF_LITE_RESIZE_BILINEAR(reference_ops); + } + if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) { + TF_LITE_RESIZE_BILINEAR(optimized_ops); + } +#undef TF_LITE_RESIZE_BILINEAR + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace resize_bilinear + +TfLiteRegistration* Register_RESIZE_BILINEAR_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR_NEON_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, resize_bilinear::Prepare, + resize_bilinear::Eval}; + return &r; +} + +TfLiteRegistration* Register_RESIZE_BILINEAR() { +#ifdef USE_NEON + return Register_RESIZE_BILINEAR_NEON_OPT(); +#else + return Register_RESIZE_BILINEAR_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..314a71e210d9b5ea75bb137ef228273ef48f28b5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ResizeBilinearOpModel : public SingleOpModel { + public: + ResizeBilinearOpModel(std::initializer_list input_shape, int new_height, + int new_width) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_, new_height, new_width).Union()); + BuildInterpreter({input_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(ResizeBilinearOpTest, HorizontalResize) { + ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3); + m.SetInput({3, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); +} + +TEST(ResizeBilinearOpTest, VerticalResize) { + ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1); + m.SetInput({3, 9}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResize) { + ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3); + m.SetInput({ + 3, 6, // + 9, 12 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); +} + +TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { + ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3); + m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); +} + +TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { + ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3); + m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc new file mode 100644 index 0000000000000000000000000000000000000000..c90a15b3a2e79028128260e579f41742a46289f6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/skip_gram.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Generate a list of skip grams from an input. +// +// Options: +// ngram_size: num of words for each output item. +// max_skip_size: max num of words to skip. +// The op generates ngrams when it is 0. +// include_all_ngrams: include all ngrams with size up to ngram_size. +// +// Input: +// A string tensor to generate n-grams. +// Dim = {1} +// +// Output: +// A list of strings, each of which contains ngram_size words. +// Dim = {num_ngram} + +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { + +namespace { + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TF_LITE_ENSURE_EQ(context, GetInput(context, node, 0)->type, kTfLiteString); + TF_LITE_ENSURE_EQ(context, GetOutput(context, node, 0)->type, kTfLiteString); + return kTfLiteOk; +} + +bool ShouldIncludeCurrentNgram(const TfLiteSkipGramParams* params, int size) { + if (size <= 0) { + return false; + } + if (params->include_all_ngrams) { + return size <= params->ngram_size; + } else { + return size == params->ngram_size; + } +} + +bool ShouldStepInRecursion(const TfLiteSkipGramParams* params, + const std::vector& stack, int stack_idx, + int num_words) { + // If current stack size and next word enumeration are within valid range. + if (stack_idx < params->ngram_size && stack[stack_idx] + 1 < num_words) { + // If this stack is empty, step in for first word enumeration. + if (stack_idx == 0) { + return true; + } + // If next word enumeration are within the range of max_skip_size. + // NOTE: equivalent to + // next_word_idx = stack[stack_idx] + 1 + // next_word_idx - stack[stack_idx-1] <= max_skip_size + 1 + if (stack[stack_idx] - stack[stack_idx - 1] <= params->max_skip_size) { + return true; + } + } + return false; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // Split sentence to words. + std::vector words; + tflite::StringRef strref = tflite::GetString(GetInput(context, node, 0), 0); + int prev_idx = 0; + for (int i = 1; i < strref.len; i++) { + if (isspace(*(strref.str + i))) { + if (i > prev_idx && !isspace(*(strref.str + prev_idx))) { + words.push_back({strref.str + prev_idx, i - prev_idx}); + } + prev_idx = i + 1; + } + } + if (strref.len > prev_idx) { + words.push_back({strref.str + prev_idx, strref.len - prev_idx}); + } + + // Generate n-grams recursively. + tflite::DynamicBuffer buf; + if (words.size() < params->ngram_size) { + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; + } + + // Stack stores the index of word used to generate ngram. + // The size of stack is the size of ngram. + std::vector stack(params->ngram_size, 0); + // Stack index that indicates which depth the recursion is operating at. + int stack_idx = 1; + int num_words = words.size(); + + while (stack_idx >= 0) { + if (ShouldStepInRecursion(params, stack, stack_idx, num_words)) { + // When current depth can fill with a new word + // and the new word is within the max range to skip, + // fill this word to stack, recurse into next depth. + stack[stack_idx]++; + stack_idx++; + if (stack_idx < params->ngram_size) { + stack[stack_idx] = stack[stack_idx - 1]; + } + } else { + if (ShouldIncludeCurrentNgram(params, stack_idx)) { + // Add n-gram to tensor buffer when the stack has filled with enough + // words to generate the ngram. + std::vector gram(stack_idx); + for (int i = 0; i < stack_idx; i++) { + gram[i] = words[stack[i]]; + } + buf.AddJoinedString(gram, ' '); + } + // When current depth cannot fill with a valid new word, + // and not in last depth to generate ngram, + // step back to previous depth to iterate to next possible word. + stack_idx--; + } + } + + buf.WriteToTensor(GetOutput(context, node, 0)); + return kTfLiteOk; +} +} // namespace + +TfLiteRegistration* Register_SKIP_GRAM() { + static TfLiteRegistration r = {nullptr, nullptr, Prepare, Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..185b64cb44969b57588ea5d0b40f55b6ddf8e11f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc @@ -0,0 +1,257 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; + +static char kSentence[] = "The quick\t brown fox\n jumps over\n the lazy dog!"; + +class SkipGramOp : public SingleOpModel { + public: + SkipGramOp(int ngram_size, int max_skip_size, bool include_all_ngrams) { + input_ = AddInput(TensorType_STRING); + output_ = AddOutput(TensorType_STRING); + + SetBuiltinOp(BuiltinOperator_SKIP_GRAM, BuiltinOptions_SkipGramOptions, + CreateSkipGramOptions(builder_, ngram_size, max_skip_size, + include_all_ngrams) + .Union()); + BuildInterpreter({{1}}); + } + void SetInput(const string& content) { + PopulateStringTensor(input_, {content}); + } + + std::vector GetOutput() { + std::vector ans; + TfLiteTensor* tensor = interpreter_->tensor(output_); + + int num = GetStringCount(tensor); + for (int i = 0; i < num; i++) { + StringRef strref = GetString(tensor, i); + ans.push_back(string(strref.str, strref.len)); + } + return ans; + } + + private: + int input_; + int output_; +}; + +TEST(SkipGramTest, TestUnigram) { + SkipGramOp m(1, 0, false); + + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), testing::UnorderedElementsAreArray( + {"The", "quick", "brown", "fox", "jumps", + "over", "the", "lazy", "dog!"})); +} + +TEST(SkipGramTest, TestBigram) { + SkipGramOp m(2, 0, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!"})); +} + +TEST(SkipGramTest, TestAllBigram) { + SkipGramOp m(2, 0, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", + "lazy", "dog!", + // Bigram + "The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!"})); +} + +TEST(SkipGramTest, TestAllTrigram) { + SkipGramOp m(3, 0, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", + "lazy", "dog!", + // Bigram + "The quick", "quick brown", "brown fox", "fox jumps", + "jumps over", "over the", "the lazy", "lazy dog!", + // Trigram + "The quick brown", "quick brown fox", "brown fox jumps", + "fox jumps over", "jumps over the", "over the lazy", + "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip1Bigram) { + SkipGramOp m(2, 1, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "The brown", "quick brown", "quick fox", "brown fox", + "brown jumps", "fox jumps", "fox over", "jumps over", "jumps the", + "over the", "over lazy", "the lazy", "the dog!", "lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip2Bigram) { + SkipGramOp m(2, 2, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick", "The brown", "The fox", "quick brown", + "quick fox", "quick jumps", "brown fox", "brown jumps", + "brown over", "fox jumps", "fox over", "fox the", + "jumps over", "jumps the", "jumps lazy", "over the", + "over lazy", "over dog!", "the lazy", "the dog!", + "lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip1Trigram) { + SkipGramOp m(3, 1, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick brown", "The quick fox", "The brown fox", + "The brown jumps", "quick brown fox", "quick brown jumps", + "quick fox jumps", "quick fox over", "brown fox jumps", + "brown fox over", "brown jumps over", "brown jumps the", + "fox jumps over", "fox jumps the", "fox over the", + "fox over lazy", "jumps over the", "jumps over lazy", + "jumps the lazy", "jumps the dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSkip2Trigram) { + SkipGramOp m(3, 2, false); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + testing::UnorderedElementsAreArray( + {"The quick brown", "The quick fox", "The quick jumps", + "The brown fox", "The brown jumps", "The brown over", + "The fox jumps", "The fox over", "The fox the", + "quick brown fox", "quick brown jumps", "quick brown over", + "quick fox jumps", "quick fox over", "quick fox the", + "quick jumps over", "quick jumps the", "quick jumps lazy", + "brown fox jumps", "brown fox over", "brown fox the", + "brown jumps over", "brown jumps the", "brown jumps lazy", + "brown over the", "brown over lazy", "brown over dog!", + "fox jumps over", "fox jumps the", "fox jumps lazy", + "fox over the", "fox over lazy", "fox over dog!", + "fox the lazy", "fox the dog!", "jumps over the", + "jumps over lazy", "jumps over dog!", "jumps the lazy", + "jumps the dog!", "jumps lazy dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestAllSkip2Trigram) { + SkipGramOp m(3, 2, true); + m.SetInput(kSentence); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + testing::UnorderedElementsAreArray( + {// Unigram + "The", "quick", "brown", "fox", "jumps", "over", "the", "lazy", + "dog!", + // Bigram + "The quick", "The brown", "The fox", "quick brown", "quick fox", + "quick jumps", "brown fox", "brown jumps", "brown over", "fox jumps", + "fox over", "fox the", "jumps over", "jumps the", "jumps lazy", + "over the", "over lazy", "over dog!", "the lazy", "the dog!", + "lazy dog!", + // Trigram + "The quick brown", "The quick fox", "The quick jumps", + "The brown fox", "The brown jumps", "The brown over", + "The fox jumps", "The fox over", "The fox the", "quick brown fox", + "quick brown jumps", "quick brown over", "quick fox jumps", + "quick fox over", "quick fox the", "quick jumps over", + "quick jumps the", "quick jumps lazy", "brown fox jumps", + "brown fox over", "brown fox the", "brown jumps over", + "brown jumps the", "brown jumps lazy", "brown over the", + "brown over lazy", "brown over dog!", "fox jumps over", + "fox jumps the", "fox jumps lazy", "fox over the", "fox over lazy", + "fox over dog!", "fox the lazy", "fox the dog!", "jumps over the", + "jumps over lazy", "jumps over dog!", "jumps the lazy", + "jumps the dog!", "jumps lazy dog!", "over the lazy", + "over the dog!", "over lazy dog!", "the lazy dog!"})); +} + +TEST(SkipGramTest, TestSingleWord) { + SkipGramOp m(1, 1, false); + m.SetInput("Hi"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre("Hi")); +} + +TEST(SkipGramTest, TestWordsLessThanGram) { + SkipGramOp m(3, 1, false); + m.SetInput("Hi hi"); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), std::vector()); +} + +TEST(SkipGramTest, TestEmptyInput) { + SkipGramOp m(1, 1, false); + m.SetInput(""); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre()); +} + +TEST(SkipGramTest, TestWhitespaceInput) { + SkipGramOp m(1, 1, false); + m.SetInput(" "); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre()); +} + +TEST(SkipGramTest, TestInputWithExtraSpace) { + SkipGramOp m(1, 1, false); + m.SetInput(" Hello world ! "); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAre("Hello", "world", "!")); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c5338ff0fd26337c9adc8e0b94a0a88edfde37f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite SOFTMAX op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +class SoftmaxOpModel : public SingleOpModel { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(SoftmaxOpTest, SimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +TEST(SoftmaxOpTest, CompareWithTFminiBetaEq1) { + const int batch_size = 2; + const int input_size = 5; + const float beta = 1.0; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + SoftmaxOpModel m(batch_size, input_size, beta); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::Softmax(input_buffer, input_dims, beta, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { + const int batch_size = 2; + const int input_size = 5; + const float beta = 0.5; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + SoftmaxOpModel m(batch_size, input_size, beta); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::Softmax(input_buffer, input_dims, beta, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb2e509c9811b1469c4d3f676532edff570a6c4a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace space_to_depth { + +// This file has two implementation of SpaceToDepth. Note that SpaceToDepth +// only works on 4D tensors. +enum KernelType { + kReference, + kGenericOptimized, +}; + +constexpr int kInputTensor = 0; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + + auto data_type = output->type; + TF_LITE_ENSURE(context, + data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8 || + data_type == kTfLiteInt32 || data_type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, input->type, output->type); + + const int block_size = params->block_size; + const int input_height = input->dims->data[1]; + const int input_width = input->dims->data[2]; + int output_height = input_height / block_size; + int output_width = input_width / block_size; + + TF_LITE_ENSURE_EQ(context, input_height, output_height * block_size); + TF_LITE_ENSURE_EQ(context, input_width, output_width * block_size); + + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + output_size->data[1] = output_height; + output_size->data[2] = output_width; + output_size->data[3] = input->dims->data[3] * block_size * block_size; + + return context->ResizeTensor(context, output, output_size); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = + reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + +#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \ + type::SpaceToDepth( \ + GetTensorData(input), GetTensorDims(input), params->block_size, \ + GetTensorData(output), GetTensorDims(output)) + switch (input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, float); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, uint8_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int32_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_DEPTH(reference_ops, int64_t); + } else { + TF_LITE_SPACE_TO_DEPTH(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, "Type not currently supported."); + return kTfLiteError; + } +#undef TF_LITE_SPACE_TO_DEPTH + + return kTfLiteOk; +} + +} // namespace space_to_depth + +TfLiteRegistration* Register_SPACE_TO_DEPTH_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_depth::Prepare, + space_to_depth::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_DEPTH_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_depth::Prepare, + space_to_depth::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_DEPTH() { + return Register_SPACE_TO_DEPTH_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..997f354861a235fb511235e4d64544dc8c3ddb34 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc @@ -0,0 +1,101 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +class SpaceToDepthOpModel : public SingleOpModel { + public: + SpaceToDepthOpModel(const TensorData& tensor_data, int block_size) { + input_ = AddInput(tensor_data); + output_ = AddOutput(tensor_data); + SetBuiltinOp(BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOptions_SpaceToDepthOptions, + CreateSpaceToDepthOptions(builder_, block_size).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int output_; +}; + +TEST(SpaceToDepthOpModel, BadBlockSize) { + EXPECT_DEATH(SpaceToDepthOpModel({TensorType_FLOAT32, {1, 2, 2, 1}}, 3), + "Cannot allocate tensors"); +} + +TEST(SpaceToDepthOpModel, Float32) { + SpaceToDepthOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}, 2); + m.SetInput({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1.4, 2.3, 3.2, 4.1, 5.4, 6.3, 7.2, 8.1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 8)); +} + +TEST(SpaceToDepthOpModel, Uint8) { + SpaceToDepthOpModel m({TensorType_UINT8, {1, 2, 2, 1}}, 2); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 4)); +} + +TEST(SpaceToDepthOpModel, Int32) { + SpaceToDepthOpModel m({TensorType_INT32, {1, 2, 2, 3}}, 2); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 1, 1, 12)); +} + +TEST(SpaceToDepthOpModel, Int64) { + SpaceToDepthOpModel m({TensorType_INT64, {1, 4, 4, 1}}, 2); + m.SetInput({1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray( + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16})); + EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 2, 4)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc new file mode 100644 index 0000000000000000000000000000000000000000..72f705fe4242b01c1516c99d3500484e8729fd9a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -0,0 +1,222 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace svdf { + +constexpr int kInputTensor = 0; +constexpr int kWeightsFeatureTensor = 1; +constexpr int kWeightsTimeTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int kStateTensor = 0; +constexpr int KOutputTensor = 1; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* scratch_tensor_index = new int; + context->AddTensors(context, 1, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + int* scratch_tensor_index = reinterpret_cast(node->user_data); + + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* weights_feature = + &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; + TfLiteTensor* weights_time = + &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int num_filters = weights_feature->dims->data[0]; + TF_LITE_ASSERT_EQ(num_filters % rank, 0); + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + TF_LITE_ASSERT_EQ(input->dims->data[1], weights_feature->dims->data[1]); + TF_LITE_ASSERT_EQ(weights_time->dims->data[0], num_filters); + + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + if (bias) { + TF_LITE_ASSERT_EQ(bias->dims->data[0], num_units); + } + + TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + + // Resize state. + // For each batch, the state is a 2-D tensor: memory_size * num_filters + // The left most column is used to save current cycle activation. + // The right most column is used to save temporary output which will be + // reduced to num_units outputs. + TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2); + state_size_array->data[0] = batch_size; + state_size_array->data[1] = memory_size * num_filters; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, state, state_size_array)); + + // Mark state as a persistent tensor. + state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); + output_size_array->data[0] = batch_size; + output_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); + + // Resize scratch. + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(1); + node->temporaries->data[0] = *scratch_tensor_index; + + TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2); + scratch_size_array->data[0] = batch_size; + scratch_size_array->data[1] = num_filters; + + TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + scratch_tensor->type = input->type; + scratch_tensor->allocation_type = kTfLiteArenaRw; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor, + scratch_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* weights_feature = + &context->tensors[node->inputs->data[kWeightsFeatureTensor]]; + TfLiteTensor* weights_time = + &context->tensors[node->inputs->data[kWeightsTimeTensor]]; + + TfLiteTensor* state = &context->tensors[node->outputs->data[kStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[KOutputTensor]]; + TfLiteTensor* scratch = &context->tensors[node->temporaries->data[0]]; + + TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor); + + const int rank = params->rank; + const int batch_size = input->dims->data[0]; + const int input_size = input->dims->data[1]; + const int num_filters = weights_feature->dims->data[0]; + const int num_units = num_filters / rank; + const int memory_size = weights_time->dims->data[1]; + + // Clear the activation (state left most column). + // TODO(ghodrat): Add a test which initialize state with invalid values in + // left most column and make sure it passes. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int c = 0; c < num_filters; c++) { + float* state_ptr = state_ptr_batch + c * memory_size; + state_ptr[memory_size - 1] = 0.0; + } + } + + // Compute conv1d(inputs, weights_feature). + // The state left most column is used to save current cycle activation. This + // is achieved by starting at state->data.f[memory_size - 1] and having the + // stride equal to memory_size. + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + weights_feature->data.f, num_filters, input_size, input->data.f, + batch_size, &state->data.f[memory_size - 1], memory_size); + + // Compute matmul(state, weights_time). + // The right most column is used to save temporary output (with the size of + // num_filters). This is achieved by starting at state->data.f and having the + // stride equal to memory_size. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::BatchVectorBatchVectorDotProduct( + weights_time->data.f, state_ptr_batch, memory_size, num_filters, + scratch_ptr_batch, /*result_stride=*/1); + } + + // Initialize output with bias if provided. + if (bias) { + tensor_utils::VectorBatchVectorAssign(bias->data.f, num_units, batch_size, + output->data.f); + } else { + tensor_utils::ZeroVector(output->data.f, batch_size * num_units); + } + + // Reduction sum + for (int b = 0; b < batch_size; b++) { + float* output_ptr_batch = output->data.f + b * num_units; + float* scratch_ptr_batch = scratch->data.f + b * num_filters; + tensor_utils::ReductionSumVector(scratch_ptr_batch, output_ptr_batch, + num_units, rank); + } + + // Apply activation. + for (int b = 0; b < batch_size; b++) { + float* output_ptr_batch = output->data.f + b * num_units; + tensor_utils::ApplyActivationToVector(output_ptr_batch, num_units, + params->activation, output_ptr_batch); + } + + // Right shift the state. + for (int b = 0; b < batch_size; b++) { + float* state_ptr_batch = state->data.f + b * memory_size * num_filters; + for (int f = 0; f < num_filters; f++) { + tensor_utils::VectorShiftLeft(state_ptr_batch, memory_size, + /*shift_value=*/0.0); + state_ptr_batch += memory_size; + } + } + return kTfLiteOk; +} + +} // namespace svdf + +TfLiteRegistration* Register_SVDF() { + static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare, + svdf::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4de2ceaf053df31a4bc857fb250db416c071e80f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -0,0 +1,312 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite SVDF op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float svdf_input[] = { + 0.12609188, -0.46347019, -0.89598465, + 0.35867718, 0.36897406, 0.73463392, + + 0.14278367, -1.64410412, -0.75222826, + -0.57290924, 0.12729003, 0.7567004, + + 0.49837467, 0.19278903, 0.26584083, + 0.17660543, 0.52949083, -0.77931279, + + -0.11186574, 0.13164264, -0.05349274, + -0.72674477, -0.5683046, 0.55900657, + + -0.68892461, 0.37783599, 0.18263303, + -0.63690937, 0.44483393, -0.71817774, + + -0.81299269, -0.86831826, 1.43940818, + -0.95760226, 1.82078898, 0.71135032, + + -1.45006323, -0.82251364, -1.69082689, + -1.65087092, -1.89238167, 1.54172635, + + 0.03966608, -0.24936394, -0.77526885, + 2.06740379, -1.51439476, 1.43768692, + + 0.11771342, -0.23761693, -0.65898693, + 0.31088525, -1.55601168, -0.87661445, + + -0.89477462, 1.67204106, -0.53235275, + -0.6230064, 0.29819036, 1.06939757, +}; + +static float svdf_golden_output_rank_1[] = { + 0.014899, -0.0517661, -0.143725, -0.00271883, + -0.03004015, 0.09565311, 0.1587342, 0.00784263, + + 0.068281, -0.162217, -0.152268, 0.00323521, + 0.01582633, 0.03858774, -0.03001583, -0.02671271, + + -0.0317821, -0.0333089, 0.0609602, 0.0333759, + -0.01432795, 0.05524484, 0.1101355, -0.02382665, + + -0.00623099, -0.077701, -0.391193, -0.0136691, + -0.02333033, 0.02293761, 0.12338032, 0.04326871, + + 0.201551, -0.164607, -0.179462, -0.0592739, + 0.01064911, -0.17503069, 0.07821996, -0.00224009, + + 0.0886511, -0.0875401, -0.269283, 0.0281379, + -0.02282338, 0.09741908, 0.32973239, 0.12281385, + + -0.201174, -0.586145, -0.628624, -0.0330412, + 0.24780814, -0.39304617, -0.22473189, 0.02589256, + + -0.0839096, -0.299329, 0.108746, 0.109808, + 0.10084175, -0.06416984, 0.28936723, 0.0026358, + + 0.419114, -0.237824, -0.422627, 0.175115, + -0.2314795, -0.18584411, -0.4228974, -0.12928449, + + 0.36726, -0.522303, -0.456502, -0.175475, + 0.17012937, -0.34447709, 0.38505614, -0.28158101, +}; + +static float svdf_golden_output_rank_2[] = { + -0.09623547, -0.10193135, 0.11083051, -0.0347917, + 0.1141196, 0.12965347, -0.12652366, 0.01007236, + + -0.16396809, -0.21247184, 0.11259045, -0.04156673, + 0.10132131, -0.06143532, -0.00924693, 0.10084561, + + 0.01257364, 0.0506071, -0.19287863, -0.07162561, + -0.02033747, 0.22673416, 0.15487903, 0.02525555, + + -0.1411963, -0.37054959, 0.01774767, 0.05867489, + 0.09607603, -0.0141301, -0.08995658, 0.12867066, + + -0.27142537, -0.16955489, 0.18521598, -0.12528358, + 0.00331409, 0.11167502, 0.02218599, -0.07309391, + + 0.09593632, -0.28361851, -0.0773851, 0.17199151, + -0.00075242, 0.33691186, -0.1536046, 0.16572715, + + -0.27916506, -0.27626723, 0.42615682, 0.3225764, + -0.37472126, -0.55655634, -0.05013514, 0.289112, + + -0.24418658, 0.07540751, -0.1940318, -0.08911639, + 0.00732617, 0.46737891, 0.26449674, 0.24888524, + + -0.17225097, -0.54660404, -0.38795233, 0.08389944, + 0.07736043, -0.28260678, 0.15666828, 1.14949894, + + -0.57454878, -0.64704704, 0.73235172, -0.34616736, + 0.21120001, -0.22927976, 0.02455296, -0.35906726, +}; + +// Derived class of SingleOpModel, which is used to test SVDF TFLite op. +class SVDFOpModel : public SingleOpModel { + public: + SVDFOpModel(int batches, int units, int input_size, int memory_size, int rank) + : batches_(batches), + units_(units), + input_size_(input_size), + memory_size_(memory_size), + rank_(rank) { + input_ = AddInput(TensorType_FLOAT32); + weights_feature_ = AddInput(TensorType_FLOAT32); + weights_time_ = AddInput(TensorType_FLOAT32); + bias_ = AddNullInput(); + state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions, + CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union()); + BuildInterpreter({ + {batches_, input_size_}, // Input tensor + {units_ * rank, input_size_}, // weights_feature tensor + {units_ * rank, memory_size_}, // weights_time tensor + {units_} // bias tensor + }); + } + + // Populates the weights_feature tensor. + void SetWeightsFeature(std::initializer_list f) { + PopulateTensor(weights_feature_, f); + } + + // Populates the weights_time tensor. + void SetWeightsTime(std::initializer_list f) { + PopulateTensor(weights_time_, f); + } + + // Populates the input tensor. + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + // Resets the state of SVDF op by filling it with 0's. + void ResetState() { + const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + // Extracts the output tensor from the SVDF op. + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + private: + int input_; + int weights_feature_; + int weights_time_; + int bias_; + int state_; + int output_; + + int batches_; + int units_; + int input_size_; + int memory_size_; + int rank_; +}; + +TEST(SVDFOpTest, BlackBoxTestRank1) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/1); + svdf.SetWeightsFeature({-0.31930989, -0.36118156, 0.0079667, 0.37613347, + 0.22197971, 0.12416199, 0.27901134, 0.27557442, + 0.3905206, -0.36137494, -0.06634006, -0.10640851}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = + svdf_golden_output_rank_1 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +TEST(SVDFOpTest, BlackBoxTestRank2) { + SVDFOpModel svdf(/*batches=*/2, /*units=*/4, /*input_size=*/3, + /*memory_size=*/10, /*rank=*/2); + svdf.SetWeightsFeature({-0.31930989, 0.0079667, 0.39296314, 0.37613347, + 0.12416199, 0.15785322, 0.27901134, 0.3905206, + 0.21931258, -0.36137494, -0.10640851, 0.31053296, + -0.36118156, -0.0976817, -0.36916667, 0.22197971, + 0.15294972, 0.38031587, 0.27557442, 0.39635518, + -0.21580373, -0.06634006, -0.02702999, 0.27072677}); + + svdf.SetWeightsTime( + {-0.31930989, 0.37613347, 0.27901134, -0.36137494, -0.36118156, + 0.22197971, 0.27557442, -0.06634006, 0.0079667, 0.12416199, + + 0.3905206, -0.10640851, -0.0976817, 0.15294972, 0.39635518, + -0.02702999, 0.39296314, 0.15785322, 0.21931258, 0.31053296, + + -0.36916667, 0.38031587, -0.21580373, 0.27072677, 0.23622236, + 0.34936687, 0.18174365, 0.35907319, -0.17493086, 0.324846, + + -0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166, + -0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657, + + -0.14884081, 0.19931212, -0.36002168, 0.34663299, -0.11405486, + 0.12672701, 0.39463779, -0.07886535, -0.06384811, 0.08249187, + + -0.26816407, -0.19905911, 0.29211238, 0.31264046, -0.28664589, + 0.05698794, 0.11613581, 0.14078894, 0.02187902, -0.21781836, + + -0.15567942, 0.08693647, -0.38256618, 0.36580828, -0.22922277, + -0.0226903, 0.12878349, -0.28122205, -0.10850525, -0.11955214, + + 0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326, + 0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763}); + + svdf.ResetState(); + const int svdf_num_batches = svdf.num_batches(); + const int svdf_input_size = svdf.input_size(); + const int svdf_num_units = svdf.num_units(); + const int input_sequence_size = + sizeof(svdf_input) / sizeof(float) / (svdf_input_size * svdf_num_batches); + // Going over each input batch, setting the input tensor, invoking the SVDF op + // and checking the output with the expected golden values. + for (int i = 0; i < input_sequence_size; i++) { + float* batch_start = svdf_input + i * svdf_input_size * svdf_num_batches; + float* batch_end = batch_start + svdf_input_size * svdf_num_batches; + svdf.SetInput(0, batch_start, batch_end); + + svdf.Invoke(); + + float* golden_start = + svdf_golden_output_rank_2 + i * svdf_num_units * svdf_num_batches; + float* golden_end = golden_start + svdf_num_units * svdf_num_batches; + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(svdf.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..f716ba8741fd469e7ee405ac300924b53c5c48e5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/test_util.h" + +#include "tensorflow/contrib/lite/version.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { + +using ::testing::FloatNear; +using ::testing::Matcher; + +namespace { +template +std::pair QuantizationParams(float f_min, float f_max) { + // These are required by many quantized operations. + CHECK_LE(f_min, 0); + CHECK_GE(f_max, 0); + T q_min = std::numeric_limits::min(); + T q_max = std::numeric_limits::max(); + float range = q_max - q_min; + float scale = (f_max - f_min) / range; + int32_t zero_point = std::min( + q_max, + std::max(q_min, static_cast(std::round(q_min - f_min / scale)))); + return {scale, zero_point}; +} +} // namespace + +std::vector> ArrayFloatNear(const std::vector& values, + float max_abs_error) { + std::vector> matchers; + matchers.reserve(values.size()); + for (const float& v : values) { + matchers.emplace_back(FloatNear(v, max_abs_error)); + } + return matchers; +} + +int SingleOpModel::AddTensor(TensorData t) { + int id = tensors_.size(); + + // This is slightly different depending on whether we are adding a + // quantized or a regular tensor. + bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0); + + flatbuffers::Offset q_params = 0; + + if (is_quantized) { + if (t.min != 0 || t.max != 0) { + if (t.type == TensorType_UINT8) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); + } else if (t.type == TensorType_INT32) { + std::tie(t.scale, t.zero_point) = + QuantizationParams(t.min, t.max); + } else { + LOG(FATAL) << "No support for the requested quantized type"; + } + t.min = 0; + t.max = 0; + } + + q_params = CreateQuantizationParameters( + builder_, /*min=*/0, /*max=*/0, builder_.CreateVector({t.scale}), + builder_.CreateVector({t.zero_point})); + } + + tensors_.push_back(CreateTensor(builder_, builder_.CreateVector({}), + t.type, /*buffer=*/0, + /*name=*/0, q_params)); + + tensor_data_[id] = t; + + return id; +} + +int SingleOpModel::AddInput(const TensorData& t) { + int id = AddTensor(t); + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddNullInput() { + int id = kOptionalTensor; + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddOutput(const TensorData& t) { + int id = AddTensor(t); + outputs_.push_back(id); + return id; +} + +void SingleOpModel::SetBuiltinOp(BuiltinOperator type, + BuiltinOptions builtin_options_type, + flatbuffers::Offset builtin_options) { + opcodes_.push_back(CreateOperatorCode(builder_, type, 0)); + operators_.push_back(CreateOperator( + builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), + builder_.CreateVector(outputs_), builtin_options_type, + builtin_options, + /*custom_options=*/0, CustomOptionsFormat_FLEXBUFFERS)); +} + +void SingleOpModel::SetCustomOp( + const string& name, const std::vector& custom_option, + const std::function& registeration) { + custom_registrations_[name] = registeration; + opcodes_.push_back( + CreateOperatorCodeDirect(builder_, BuiltinOperator_CUSTOM, name.data())); + operators_.push_back(CreateOperator( + builder_, /*opcode_index=*/0, builder_.CreateVector(inputs_), + builder_.CreateVector(outputs_), BuiltinOptions_NONE, 0, + builder_.CreateVector(custom_option), + CustomOptionsFormat_FLEXBUFFERS)); +} + +void SingleOpModel::BuildInterpreter( + std::vector> input_shapes) { + auto opcodes = builder_.CreateVector(opcodes_); + auto operators = builder_.CreateVector(operators_); + auto tensors = builder_.CreateVector(tensors_); + auto inputs = builder_.CreateVector(inputs_); + auto outputs = builder_.CreateVector(outputs_); + // Create a single subgraph + std::vector> subgraphs; + auto subgraph = CreateSubGraph(builder_, tensors, inputs, outputs, operators); + subgraphs.push_back(subgraph); + auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs); + + std::vector> buffers_vec; + auto buffers = builder_.CreateVector(buffers_vec); + auto description = builder_.CreateString("programmatic model"); + builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, + subgraphs_flatbuffer, description, buffers)); + + auto* model = GetModel(builder_.GetBufferPointer()); + + ops::builtin::BuiltinOpResolver builtins; + for (const auto& reg : custom_registrations_) { + builtins.AddCustom(reg.first.data(), reg.second()); + } + InterpreterBuilder(model, builtins)(&interpreter_); + + CHECK(interpreter_ != nullptr); + + int i = 0; + for (const auto& shape : input_shapes) { + int input_idx = interpreter_->inputs()[i++]; + if (input_idx == kOptionalTensor) continue; + CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); + } + CHECK(interpreter_->AllocateTensors() == kTfLiteOk) + << "Cannot allocate tensors"; +} + +void SingleOpModel::Invoke() { CHECK(interpreter_->Invoke() == kTfLiteOk); } + +int32_t SingleOpModel::GetTensorSize(int index) const { + TfLiteTensor* t = interpreter_->tensor(index); + CHECK(t); + int total_size = 1; + for (int i = 0; i < t->dims->size; ++i) { + total_size *= t->dims->data[i]; + } + return total_size; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h new file mode 100644 index 0000000000000000000000000000000000000000..adcdeddbfc9d3b3313b09cd6310171160e0be645 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ + +#include + +#include +#include + +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { + +// A gmock matcher that check that elements of a float vector match to a given +// tolerance. +std::vector<::testing::Matcher> ArrayFloatNear( + const std::vector& values, float max_abs_error = 1e-5); + +template +inline std::vector Quantize(const std::vector& data, float scale, + int32_t zero_point) { + std::vector q; + for (float f : data) { + q.push_back(std::max( + std::numeric_limits::min(), + std::min(std::numeric_limits::max(), + static_cast(std::round(zero_point + (f / scale)))))); + } + return q; +} + +template +inline std::vector Dequantize(const std::vector& data, float scale, + int32_t zero_point) { + std::vector f; + for (T q : data) { + f.push_back(scale * (q - zero_point)); + } + return f; +} + +// A test model that contains a single operator. All operator inputs and +// output are external to the model, so the tests can directly access them. +// Typical usage: +// SingleOpModel m; +// int a = m.AddInput({TensorType_FLOAT32, a_shape}); +// int b = m.AddInput({TensorType_FLOAT32, b_shape}); +// int c = m.AddOutput({TensorType_FLOAT32, {}}); +// m.SetBuiltinOp(...); +// m.BuildInterpreter({GetShape(a), GetShape(b)}); +// m.PopulateTensor(a, {...}); +// m.PopulateTensor(b, {...}); +// m.Invoke(); +// EXPECT_THAT(m.ExtractVector(c), ArrayFloatNear({...})); +// + +// A helper struct to construct test tensors. This is particularly useful for +// quantized tensor which must have their scale and zero_point defined before +// the actual data is known. This mimics what happens in practice: quantization +// parameters are calculate during training. +struct TensorData { + TensorType type; + std::vector shape; + float min; + float max; + float scale; + int32_t zero_point; +}; + +class SingleOpModel { + public: + SingleOpModel() {} + ~SingleOpModel() {} + + // Copying or assignment is disallowed to simplify ownership semantics. + SingleOpModel(const SingleOpModel&) = delete; + SingleOpModel& operator=(const SingleOpModel&) = delete; + + // Add a TensorType input tensor and return its index. + int AddInput(TensorType type) { return AddInput(TensorData{type}); } + int AddInput(const TensorData& t); + + // Add a null input tensor (optional input) and return kOptionalTensor. + int AddNullInput(); + + // Add a TensorType output tensor and return its index. + int AddOutput(TensorType type) { return AddOutput(TensorData{type}); } + int AddOutput(const TensorData& t); + + template + void QuantizeAndPopulate(int index, std::initializer_list data) { + TfLiteTensor* t = interpreter_->tensor(index); + auto q = Quantize(data, t->params.scale, t->params.zero_point); + PopulateTensor(index, 0, q.data(), q.data() + q.size()); + } + + const std::vector& GetShape(int id) { return tensor_data_.at(id).shape; } + + float GetScale(int id) { return tensor_data_.at(id).scale; } + int32_t GetZeroPoint(int id) { return tensor_data_.at(id).zero_point; } + + // Define the operator in this model. + void SetBuiltinOp(BuiltinOperator type, BuiltinOptions builtin_options_type, + flatbuffers::Offset builtin_options); + void SetCustomOp(const string& name, + const std::vector& custom_option, + const std::function& registeration); + + // Build the interpreter for this model. Also, resize and allocate all + // tensors given the shapes of the inputs. + void BuildInterpreter(std::vector> input_shapes); + + void Invoke(); + + void PopulateStringTensor(int index, const std::vector& content) { + auto tensor = interpreter_->tensor(index); + DynamicBuffer buf; + for (const string& s : content) { + buf.AddString(s.data(), s.length()); + } + buf.WriteToTensor(tensor); + } + + // Populate the tensor given its index. + template + void PopulateTensor(int index, std::initializer_list data) { + T* v = interpreter_->typed_tensor(index); + CHECK(v) << "No tensor with index '" << index << "'."; + for (T f : data) { + *v = f; + ++v; + } + } + + // Partially populate the tensor, starting at the given offset. + template + void PopulateTensor(int index, int offset, T* begin, T* end) { + T* v = interpreter_->typed_tensor(index); + memcpy(v + offset, begin, (end - begin) * sizeof(T)); + } + + // Return a vector with the flattened contents of a tensor. + template + std::vector ExtractVector(int index) { + T* v = interpreter_->typed_tensor(index); + CHECK(v); + return std::vector(v, v + GetTensorSize(index)); + } + + std::vector GetTensorShape(int index) { + std::vector result; + TfLiteTensor* t = interpreter_->tensor(index); + for (int i = 0; i < t->dims->size; ++i) { + result.push_back(t->dims->data[i]); + } + return result; + } + + protected: + int32_t GetTensorSize(int index) const; + + flatbuffers::FlatBufferBuilder builder_; + std::unique_ptr interpreter_; + + private: + int AddTensor(TensorData t); + + std::map tensor_data_; + std::vector inputs_; + std::vector outputs_; + std::vector> tensors_; + std::vector> opcodes_; + std::vector> operators_; + std::map> custom_registrations_; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc new file mode 100644 index 0000000000000000000000000000000000000000..54efad94afa73ccdfb3f26513e934c7eb5001400 --- /dev/null +++ b/tensorflow/contrib/lite/model.cc @@ -0,0 +1,700 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +namespace { +inline const tflite::Model* VerifyAndGetModel(const void* buf, size_t len) { + ::flatbuffers::Verifier verifier(static_cast(buf), len); + if (VerifyModelBuffer(verifier)) { + return ::tflite::GetModel(buf); + } else { + return nullptr; + } +} +} // namespace + +const char* kEmptyTensorName = ""; + +std::unique_ptr FlatBufferModel::BuildFromFile( + const char* filename, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter, + /*use_nnapi=*/true)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::BuildFromBuffer( + const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr FlatBufferModel::BuildFromModel( + const tflite::Model* model_spec, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(model_spec, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, + ErrorReporter* error_reporter, bool use_nnapi) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + if (mmap_file) { + if (use_nnapi && NNAPIExists()) + allocation_ = new NNAPIAllocation(filename, error_reporter); + else + allocation_ = new MMAPAllocation(filename, error_reporter); + } else { + allocation_ = new FileCopyAllocation(filename, error_reporter); + } + if (!allocation_->valid()) return; + if (!CheckModelIdentifier()) return; + + model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); +} + +bool FlatBufferModel::CheckModelIdentifier() const { + if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { + const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); + error_reporter_->Report( + "Model provided has model identifier '%c%c%c%c', should be '%s'\n", + ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); + return false; + } + return true; +} + +FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter); + if (!allocation_->valid()) return; + + model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); +} + +FlatBufferModel::FlatBufferModel(const Model* model, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + model_ = model; +} + +FlatBufferModel::~FlatBufferModel() { delete allocation_; } + +InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver) + : model_(model.GetModel()), + op_resolver_(op_resolver), + error_reporter_(model.error_reporter()), + allocation_(model.allocation()) {} + +InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter) + : model_(model), + op_resolver_(op_resolver), + error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) {} + +TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { + TfLiteStatus status = kTfLiteOk; + auto opcodes = model_->operator_codes(); + for (const OperatorCode* opcode : *opcodes) { + TfLiteRegistration* registration = nullptr; + + if (opcode->builtin_code() != BuiltinOperator_CUSTOM) { + auto x = opcode->builtin_code(); + flatbuffer_op_index_to_registration_types_.push_back(x); + registration = op_resolver_.FindOp(x); + if (registration == nullptr) { + error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", + EnumNameBuiltinOperator(x)); + status = kTfLiteError; + } + } else if (!opcode->custom_code()) { + error_reporter_->Report( + "Operator with builtin_code==0 has no custom_code.\n"); + status = kTfLiteError; + } else { + const char* name = opcode->custom_code()->c_str(); + registration = op_resolver_.FindOp(name); + flatbuffer_op_index_to_registration_types_.push_back( + BuiltinOperator_CUSTOM); + if (registration == nullptr) { + error_reporter_->Report("Didn't find custom op for name '%s'\n", name); + status = kTfLiteError; + } + } + flatbuffer_op_index_to_registration_.push_back(registration); + } + return status; +} + +namespace { +template +std::vector FlatBufferIntArrayToVector(T* flat_array) { + std::vector ret(flat_array->Length()); + for (int i = 0; i < flat_array->Length(); i++) { + ret[i] = flat_array->Get(i); + } + return ret; +} + +// Allocate a structure using C malloc, but make sure the structure is a +// POD structure that doesn't require constructors to run. The reason we do +// this, is that Interpreter's C extension part will take ownership and wants +// to use malloc() and free(). +template +T* MallocPOD() { + static_assert(std::is_pod::value, "Builtin data structure must be POD."); + return static_cast(malloc(sizeof(T))); +} + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// +// Returns memory that must be feed. +void* ParseOpData(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter) { + auto parse_padding = [](Padding padding) { + switch (padding) { + case Padding_SAME: + return kTfLitePaddingSame; + case Padding_VALID: + return kTfLitePaddingValid; + } + return kTfLitePaddingUnknown; + }; + auto parse_activation = [](ActivationFunctionType activation) { + switch (activation) { + case ActivationFunctionType_NONE: + return kTfLiteActNone; + case ActivationFunctionType_RELU: + return kTfLiteActRelu; + case ActivationFunctionType_RELU1: + return kTfLiteActRelu1; + case ActivationFunctionType_RELU6: + return kTfLiteActRelu6; + case ActivationFunctionType_TANH: + return kTfLiteActTanh; + case ActivationFunctionType_SIGN_BIT: + return kTfLiteActSignBit; + } + return kTfLiteActNone; + }; + auto parseLSHProjectionType = [](LSHProjectionType type) { + switch (type) { + case LSHProjectionType_SPARSE: + return kTfLiteLshProjectionSparse; + case LSHProjectionType_DENSE: + return kTfLiteLshProjectionDense; + default: + return kTfLiteLshProjectionUnknown; + } + }; + auto parseCombinerType = [](CombinerType type) { + switch (type) { + case CombinerType_MEAN: + return kTfLiteCombinerTypeMean; + case CombinerType_SQRTN: + return kTfLiteCombinerTypeSqrtn; + case CombinerType_SUM: + default: + return kTfLiteCombinerTypeSum; + } + }; + + void* builtin_data = nullptr; + switch (op_type) { + case BuiltinOperator_CALL: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + break; + case BuiltinOperator_CUSTOM: + break; + case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_TANH: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU1: + case BuiltinOperator_RELU6: + case BuiltinOperator_CONCAT_EMBEDDINGS: + break; + case BuiltinOperator_LSH_PROJECTION: { + TfLiteLSHProjectionParams* params = + MallocPOD(); + if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { + params->type = parseLSHProjectionType(lshParams->type()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_AVERAGE_POOL_2D: + case BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator_L2_POOL_2D: { + TfLitePoolParams* params = MallocPOD(); + if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { + params->padding = parse_padding(pool_params->padding()); + params->stride_width = pool_params->stride_w(); + params->stride_height = pool_params->stride_h(); + params->filter_width = pool_params->filter_width(); + params->filter_height = pool_params->filter_height(); + params->activation = + parse_activation(pool_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_DEPTHWISE_CONV_2D: { + TfLiteDepthwiseConvParams* params = + MallocPOD(); + if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->depth_multiplier = conv_params->depth_multiplier(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SVDF: { + TfLiteSVDFParams* params = MallocPOD(); + if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { + params->rank = svdf_params->rank(); + params->activation = + parse_activation(svdf_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RNN: { + TfLiteRNNParams* params = MallocPOD(); + if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + parse_activation(rnn_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_EMBEDDING_LOOKUP: + // no-op. + break; + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + TfLiteEmbeddingLookupSparseParams* params = + MallocPOD(); + if (auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_FULLY_CONNECTED: { + TfLiteFullyConnectedParams* params = + MallocPOD(); + if (auto* fully_connected_params = + op->builtin_options_as_FullyConnectedOptions()) { + params->activation = parse_activation( + fully_connected_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. + break; + case BuiltinOperator_SOFTMAX: { + TfLiteSoftmaxParams* params = MallocPOD(); + if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { + params->beta = softmax_params->beta(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_CONCATENATION: { + TfLiteConcatenationParams* params = + MallocPOD(); + if (auto* concatenation_params = + op->builtin_options_as_ConcatenationOptions()) { + params->activation = + parse_activation(concatenation_params->fused_activation_function()); + params->axis = concatenation_params->axis(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_MUL: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_MulOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_ADD: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_AddOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_L2_NORMALIZATION: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_LSTM: { + TfLiteLSTMParams* params = MallocPOD(); + if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + params->activation = + parse_activation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RESIZE_BILINEAR: { + auto* params = MallocPOD(); + if (auto* schema_params = + op->builtin_options_as_ResizeBilinearOptions()) { + params->new_height = schema_params->new_height(); + params->new_width = schema_params->new_width(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_RESHAPE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { + auto* new_shape = schema_params->new_shape(); + if (!new_shape) { + error_reporter->Report("No new_shape provided for Reshape\n"); + } else { + params->num_dimensions = new_shape->Length(); + if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in Reshape's new_shape\n"); + } else { + for (int i = 0; i < params->num_dimensions; ++i) { + params->shape[i] = new_shape->Get(i); + } + } + } + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SKIP_GRAM: { + TfLiteSkipGramParams* params = MallocPOD(); + if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { + params->ngram_size = skip_gram_params->ngram_size(); + params->max_skip_size = skip_gram_params->max_skip_size(); + params->include_all_ngrams = skip_gram_params->include_all_ngrams(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SPACE_TO_DEPTH: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { + params->block_size = schema_params->block_size(); + } + builtin_data = reinterpret_cast(params); + break; + } + } + return builtin_data; +} + +} // namespace + +TfLiteStatus InterpreterBuilder::ParseNodes( + const flatbuffers::Vector>* operators, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + for (int i = 0; i < operators->Length(); ++i) { + const auto* op = operators->Get(i); + int index = op->opcode_index(); + if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) { + error_reporter_->Report("Missing registration for opcode_index %d\n", + index); + status = kTfLiteError; + continue; + } + const TfLiteRegistration* reg = + flatbuffer_op_index_to_registration_[op->opcode_index()]; + if (reg == nullptr) { + error_reporter_->Report("Skipping op for opcode_index %d\n", index); + status = kTfLiteError; + continue; + } + + auto op_type = + flatbuffer_op_index_to_registration_types_[op->opcode_index()]; + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { + error_reporter_->Report( + "Found builtin operator %s with custom options.\n", + EnumNameBuiltinOperator(op_type)); + } + if (op->custom_options()) { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), + reinterpret_cast(op->custom_options()->data()), + op->custom_options()->size(), nullptr, reg); + } else { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, + ParseOpData(op, op_type, error_reporter_), reg); + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + + // A little helper to get the names of inputs and outputs. Note that they + // must outlive the interpreter. + auto get_name = [](const tflite::Tensor* t) -> const char* { + auto name = t->name(); + if (name) return name->c_str(); + return kEmptyTensorName; + }; + + for (int i = 0; i < tensors->Length(); ++i) { + const auto* tensor = tensors->Get(i); + std::vector dims = FlatBufferIntArrayToVector(tensor->shape()); + + TfLiteQuantizationParams quantization; + quantization.scale = 0; + quantization.zero_point = 0; + auto* q_params = tensor->quantization(); + if (q_params) { + // Note that the schema could hold per-channel quantization parameters + // but we really only support one value for the whole tensor. + // TODO(aselle): This breaks as well if these are nullptr's. + // TODO(aselle): This assumes non per-channel quantization. + if (q_params->scale()) quantization.scale = q_params->scale()->Get(0); + if (q_params->zero_point()) + quantization.zero_point = q_params->zero_point()->Get(0); + } + + TfLiteType type; + switch (tensor->type()) { + case TensorType_FLOAT32: + type = kTfLiteFloat32; + break; + case TensorType_INT32: + type = kTfLiteInt32; + break; + case TensorType_UINT8: + type = kTfLiteUInt8; + break; + case TensorType_INT64: + type = kTfLiteInt64; + break; + case TensorType_STRING: + type = kTfLiteString; + break; + default: + // tensorType = ArrayType::NONE; + error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor->type()), + tensor->type()); + status = kTfLiteError; + continue; + } + auto get_readonly_data = [&](const char** buffer_data, + size_t* buffer_size) { + // TODO(aselle): Check what happens if we have an unspecified size + // constant. + *buffer_data = nullptr; + if (tensor->buffer() == 0) return kTfLiteOk; + if (tensor->buffer() >= buffers->size()) { + error_reporter_->Report( + "Tensor %d specifies out of range buffer %d (only %d buffers).\n", + i, tensor->buffer(), buffers->size()); + return kTfLiteError; + } + if (auto* buffer = (*buffers)[tensor->buffer()]) { + if (auto* array = buffer->data()) { + if (size_t size = array->size()) { + *buffer_size = size; + *buffer_data = reinterpret_cast(array->data()); + return kTfLiteOk; + } + } + } + return kTfLiteOk; + }; + size_t buffer_size = 0; + const char* buffer_ptr; + TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + + if (buffer_ptr) { + if (interpreter->SetTensorParametersReadOnly( + i, type, get_name(tensor), dims, quantization, buffer_ptr, + buffer_size, allocation_) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } else { + if (interpreter->SetTensorParametersReadWrite( + i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::operator()( + std::unique_ptr* interpreter) { + if (!interpreter) { + error_reporter_->Report( + "Null output pointer passed to InterpreterBuilder."); + return kTfLiteError; + } + + // Safe exit by deleting partially created interpreter, to reduce verbosity + // on error conditions. Use by return cleanup_on_error(); + auto cleanup_and_error = [&interpreter]() { + interpreter->reset(); + return kTfLiteError; + }; + + if (!model_) { + error_reporter_->Report("Null pointer passed in as model."); + return cleanup_and_error(); + } + + if (model_->version() != TFLITE_SCHEMA_VERSION) { + error_reporter_->Report( + "Model provided is schema version %d not equal " + "to supported version %d.\n", + model_->version(), TFLITE_SCHEMA_VERSION); + return cleanup_and_error(); + } + + if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) { + error_reporter_->Report("Registration failed.\n"); + return cleanup_and_error(); + } + + // Flatbuffer model schemas define a list of opcodes independent of the graph. + // We first map those to registrations. This reduces string lookups for custom + // ops since we only do it once per custom op rather than once per custom op + // invocation in the model graph. + // Construct interpreter with correct number of tensors and operators. + auto* subgraphs = model_->subgraphs(); + auto* buffers = model_->buffers(); + if (subgraphs->size() != 1) { + error_reporter_->Report("Only 1 subgraph is currently supported.\n"); + return cleanup_and_error(); + } + const tflite::SubGraph* subgraph = (*subgraphs)[0]; + auto operators = subgraph->operators(); + auto tensors = subgraph->tensors(); + if (!operators || !tensors || !buffers) { + error_reporter_->Report( + "Did not get operators, tensors, or buffers in input flat buffer.\n"); + return cleanup_and_error(); + } + interpreter->reset(new Interpreter(error_reporter_)); + if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { + return cleanup_and_error(); + } + + // Parse inputs/outputs + (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); + (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); + + // Finally setup nodes and tensors + if (ParseNodes(operators, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h new file mode 100644 index 0000000000000000000000000000000000000000..e0c96f7f0480cd3146f95a22957477809cf0096d --- /dev/null +++ b/tensorflow/contrib/lite/model.h @@ -0,0 +1,176 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Deserialization infrastructure for tflite. Provides functionality +// to go from a serialized tflite model in flatbuffer format to an +// interpreter. +// +// using namespace tflite; +// StderrReporter error_reporter; +// auto model = FlatBufferModel::BuildFromFile("interesting_model.tflite", +// &error_reporter); +// MyOpResolver resolver; // You need to subclass OpResolver to provide +// // implementations. +// InterpreterBuilder builder(*model, resolver); +// std::unique_ptr interpreter; +// if(builder(&interpreter) == kTfLiteOk) { +// .. run model inference with interpreter +// } +// +// OpResolver must be defined to provide your kernel implementations to the +// interpreter. This is environment specific and may consist of just the builtin +// ops, or some custom operators you defined to extend tflite. +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ + +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { + +// An RAII object that represents a read-only tflite model, copied from disk, +// or mmapped. This uses flatbuffers as the serialization format. +class FlatBufferModel { + public: + // Builds a model based on a file. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromFile( + const char* filename, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Builds a model based on a pre-loaded flatbuffer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromBuffer( + const char* buffer, size_t buffer_size, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Builds a model directly from a flatbuffer pointer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromModel( + const tflite::Model* model_spec, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Releases memory or unmaps mmaped meory. + ~FlatBufferModel(); + + // Copying or assignment is disallowed to simplify ownership semantics. + FlatBufferModel(const FlatBufferModel&) = delete; + FlatBufferModel& operator=(const FlatBufferModel&) = delete; + + bool initialized() const { return model_ != nullptr; } + const tflite::Model* operator->() const { return model_; } + const tflite::Model* GetModel() const { return model_; } + ErrorReporter* error_reporter() const { return error_reporter_; } + const Allocation* allocation() const { return allocation_; } + + // Returns true if the model identifier is correct (otherwise false and + // reports an error). + bool CheckModelIdentifier() const; + + private: + // Loads a model from `filename`. If `mmap_file` is true then use mmap, + // otherwise make a copy of the model in a buffer. + // + // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be + // used. + explicit FlatBufferModel( + const char* filename, bool mmap_file = true, + ErrorReporter* error_reporter = DefaultErrorReporter(), + bool use_nnapi = false); + + // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has + // to remain alive and unchanged until the end of this flatbuffermodel's + // lifetime. + // + // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be + // used. + FlatBufferModel(const char* ptr, size_t num_bytes, + ErrorReporter* error_reporter = DefaultErrorReporter()); + + // Loads a model from Model flatbuffer. The `model` has to remain alive and + // unchanged until the end of this flatbuffermodel's lifetime. + FlatBufferModel(const Model* model, ErrorReporter* error_reporter); + + // Flatbuffer traverser pointer. (Model* is a pointer that is within the + // allocated memory of the data allocated by allocation's internals. + const tflite::Model* model_ = nullptr; + ErrorReporter* error_reporter_; + Allocation* allocation_ = nullptr; +}; + +// Abstract interface that returns TfLiteRegistrations given op codes or custom +// op names. This is the mechanism that ops being referenced in the flatbuffer +// model are mapped to executable function pointers (TfLiteRegistrations). +class OpResolver { + public: + // Finds the op registration for a builtin operator by enum code. + virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; + // Finds the op registration of a custom operator by op name. + virtual TfLiteRegistration* FindOp(const char* op) const = 0; + virtual ~OpResolver() {} +}; + +// Build an interpreter capable of interpreting `model`. +// +// model: a scoped model whose lifetime must be at least as long as +// the interpreter. In principle multiple interpreters can be made from +// a single model. +// op_resolver: An instance that implements the Resolver interface which maps +// custom op names and builtin op codes to op registrations. +// reportError: a functor that is called to report errors that handles +// printf var arg semantics. The lifetime of the reportError object must +// be greater than or equal to the Interpreter created by operator(). +// +// Returns a kTfLiteOk when successful and sets interpreter to a valid +// Interpreter. Note: the user must ensure the model lifetime is at least as +// long as interpreter's lifetime. +class InterpreterBuilder { + public: + InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver); + // Builds an interpreter given only the raw flatbuffer Model object (instead + // of a FlatBufferModel). Mostly used for testing. + // If `error_reporter` is null, then DefaultErrorReporter() is used. + InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter = DefaultErrorReporter()); + InterpreterBuilder(const InterpreterBuilder&) = delete; + InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; + TfLiteStatus operator()(std::unique_ptr* interpreter); + + private: + TfLiteStatus BuildLocalIndexToRegistrationMapping(); + TfLiteStatus ParseNodes( + const flatbuffers::Vector>* operators, + Interpreter* interpreter); + TfLiteStatus ParseTensors( + const flatbuffers::Vector>* buffers, + const flatbuffers::Vector>* tensors, + Interpreter* interpreter); + + const ::tflite::Model* model_; + const OpResolver& op_resolver_; + ErrorReporter* error_reporter_; + + std::vector flatbuffer_op_index_to_registration_; + std::vector flatbuffer_op_index_to_registration_types_; + const Allocation* allocation_ = nullptr; +}; + +} // namespace tflite + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5330c8f594593655b2a8776cf6b399c0d16cdc19 --- /dev/null +++ b/tensorflow/contrib/lite/model_test.cc @@ -0,0 +1,290 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/model.h" + +#include +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/testing/util.h" + +// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, +// we must declare this in global namespace, so argument-dependent operator +// lookup works. +inline bool operator==(const TfLiteRegistration& a, + const TfLiteRegistration& b) { + return a.invoke == b.invoke && a.init == b.init && a.prepare == b.prepare && + a.free == b.free; +} + +namespace tflite { + +// Provide a dummy operation that does nothing. +namespace { +void* dummy_init(TfLiteContext*, const char*, size_t) { return nullptr; } +void dummy_free(TfLiteContext*, void*) {} +TfLiteStatus dummy_resize(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } +TfLiteStatus dummy_invoke(TfLiteContext*, TfLiteNode*) { return kTfLiteOk; } +TfLiteRegistration dummy_reg = {dummy_init, dummy_free, dummy_resize, + dummy_invoke}; +} // namespace + +// Provide a trivial resolver that returns a constant value no matter what +// op is asked for. +class TrivialResolver : public OpResolver { + public: + explicit TrivialResolver(TfLiteRegistration* constant_return = nullptr) + : constant_return_(constant_return) {} + // Find the op registration of a custom operator by op name. + TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const override { + return constant_return_; + } + // Find the op registration of a custom operator by op name. + TfLiteRegistration* FindOp(const char* op) const override { + return constant_return_; + } + + private: + TfLiteRegistration* constant_return_; +}; + +TEST(BasicFlatBufferModel, TestNonExistantFiles) { + ASSERT_TRUE(!FlatBufferModel::BuildFromFile("/tmp/tflite_model_1234")); +} + +// Make sure a model with nothing in it loads properly. +TEST(BasicFlatBufferModel, TestEmptyModelsAndNullDestination) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin"); + ASSERT_TRUE(model); + // Now try to build it into a model. + std::unique_ptr interpreter; + ASSERT_EQ(InterpreterBuilder(*model, TrivialResolver())(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_NE(InterpreterBuilder(*model, TrivialResolver())(nullptr), kTfLiteOk); +} + +// Make sure currently unsupported # of subgraphs are checked +// TODO(aselle): Replace this test when multiple subgraphs are supported. +TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) { + auto m1 = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/0_subgraphs.bin"); + ASSERT_TRUE(m1); + std::unique_ptr interpreter1; + ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1), + kTfLiteOk); + + auto m2 = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/2_subgraphs.bin"); + ASSERT_TRUE(m2); + std::unique_ptr interpreter2; + ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2), + kTfLiteOk); +} + +// Test what happens if we cannot bind any of the ops. +TEST(BasicFlatBufferModel, TestModelWithoutNullRegistrations) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin"); + ASSERT_TRUE(model); + // Check that we get an error code and interpreter pointer is reset. + std::unique_ptr interpreter(new Interpreter); + ASSERT_NE(InterpreterBuilder(*model, TrivialResolver(nullptr))(&interpreter), + kTfLiteOk); + ASSERT_EQ(interpreter, nullptr); +} + +// Make sure model is read to interpreter propelrly +TEST(BasicFlatBufferModel, TestModelInInterpreter) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model.bin"); + ASSERT_TRUE(model); + // Check that we get an error code and interpreter pointer is reset. + std::unique_ptr interpreter(new Interpreter); + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); + ASSERT_EQ(interpreter->tensors_size(), 4); + ASSERT_EQ(interpreter->nodes_size(), 2); + std::vector inputs = {0, 1}; + std::vector outputs = {2, 3}; + ASSERT_EQ(interpreter->inputs(), inputs); + ASSERT_EQ(interpreter->outputs(), outputs); + + EXPECT_EQ(std::string(interpreter->GetInputName(0)), "input0"); + EXPECT_EQ(std::string(interpreter->GetInputName(1)), "input1"); + EXPECT_EQ(std::string(interpreter->GetOutputName(0)), "out1"); + EXPECT_EQ(std::string(interpreter->GetOutputName(1)), "out2"); + + // Make sure all input tensors are correct + TfLiteTensor* i0 = interpreter->tensor(0); + ASSERT_EQ(i0->type, kTfLiteFloat32); + ASSERT_NE(i0->data.raw, nullptr); // mmapped + ASSERT_EQ(i0->allocation_type, kTfLiteMmapRo); + TfLiteTensor* i1 = interpreter->tensor(1); + ASSERT_EQ(i1->type, kTfLiteFloat32); + ASSERT_EQ(i1->data.raw, nullptr); + ASSERT_EQ(i1->allocation_type, kTfLiteArenaRw); + TfLiteTensor* o0 = interpreter->tensor(2); + ASSERT_EQ(o0->type, kTfLiteFloat32); + ASSERT_EQ(o0->data.raw, nullptr); + ASSERT_EQ(o0->allocation_type, kTfLiteArenaRw); + TfLiteTensor* o1 = interpreter->tensor(3); + ASSERT_EQ(o1->type, kTfLiteFloat32); + ASSERT_EQ(o1->data.raw, nullptr); + ASSERT_EQ(o1->allocation_type, kTfLiteArenaRw); + + // Check op 0 which has inputs {0, 1} outputs {2}. + { + const std::pair* node_and_reg0 = + interpreter->node_and_registration(0); + ASSERT_NE(node_and_reg0, nullptr); + const TfLiteNode& node0 = node_and_reg0->first; + const TfLiteRegistration& reg0 = node_and_reg0->second; + TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(2); + desired_inputs->data[0] = 0; + desired_inputs->data[1] = 1; + TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1); + desired_outputs->data[0] = 2; + ASSERT_TRUE(TfLiteIntArrayEqual(node0.inputs, desired_inputs)); + ASSERT_TRUE(TfLiteIntArrayEqual(node0.outputs, desired_outputs)); + TfLiteIntArrayFree(desired_inputs); + TfLiteIntArrayFree(desired_outputs); + ASSERT_EQ(reg0, dummy_reg); + } + + // Check op 1 which has inputs {2} outputs {3}. + { + const std::pair* node_and_reg1 = + interpreter->node_and_registration(1); + ASSERT_NE(node_and_reg1, nullptr); + const TfLiteNode& node1 = node_and_reg1->first; + const TfLiteRegistration& reg1 = node_and_reg1->second; + TfLiteIntArray* desired_inputs = TfLiteIntArrayCreate(1); + TfLiteIntArray* desired_outputs = TfLiteIntArrayCreate(1); + desired_inputs->data[0] = 2; + desired_outputs->data[0] = 3; + ASSERT_TRUE(TfLiteIntArrayEqual(node1.inputs, desired_inputs)); + ASSERT_TRUE(TfLiteIntArrayEqual(node1.outputs, desired_outputs)); + TfLiteIntArrayFree(desired_inputs); + TfLiteIntArrayFree(desired_outputs); + ASSERT_EQ(reg1, dummy_reg); + } +} + +// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped +// buffer. But the buffer is provided to be only 1 element. +TEST(BasicFlatBufferModel, TestBrokenMmap) { + ASSERT_FALSE(FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/test_model_broken.bin")); +} + +TEST(BasicFlatBufferModel, TestNullModel) { + // Check that we get an error code and interpreter pointer is reset. + std::unique_ptr interpreter(new Interpreter); + ASSERT_NE( + InterpreterBuilder(nullptr, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_EQ(interpreter.get(), nullptr); +} + +struct TestErrorReporter : public ErrorReporter { + int Report(const char* format, va_list args) override { + calls++; + return 0; + } + int calls = 0; +}; + +// This makes sure the ErrorReporter is marshalled from FlatBufferModel to +// the Interpreter. +TEST(BasicFlatBufferModel, TestCustomErrorReporter) { + TestErrorReporter reporter; + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin", + &reporter); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + TrivialResolver resolver; + InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter->Invoke(), kTfLiteOk); + ASSERT_EQ(reporter.calls, 1); +} + +// This makes sure the ErrorReporter is marshalled from FlatBufferModel to +// the Interpreter. +TEST(BasicFlatBufferModel, TestNullErrorReporter) { + auto model = FlatBufferModel::BuildFromFile( + "tensorflow/contrib/lite/testdata/empty_model.bin", nullptr); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + TrivialResolver resolver; + InterpreterBuilder(*model, resolver)(&interpreter); + ASSERT_NE(interpreter->Invoke(), kTfLiteOk); +} + +// Test what happens if we cannot bind any of the ops. +TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { + std::string corrupted_data = "123"; + auto model = FlatBufferModel::BuildFromBuffer(corrupted_data.c_str(), + corrupted_data.length()); + ASSERT_FALSE(model); +} + +// Test that loading model directly from a Model flatbuffer works. +TEST(BasicFlatBufferModel, TestBuildFromModel) { + TestErrorReporter reporter; + FileCopyAllocation model_allocation( + "tensorflow/contrib/lite/testdata/test_model.bin", &reporter); + ASSERT_TRUE(model_allocation.valid()); + ::flatbuffers::Verifier verifier( + reinterpret_cast(model_allocation.base()), + model_allocation.bytes()); + ASSERT_TRUE(VerifyModelBuffer(verifier)); + const Model* model_fb = ::tflite::GetModel(model_allocation.base()); + + auto model = FlatBufferModel::BuildFromModel(model_fb); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); +} + +// TODO(aselle): Add tests for serialization of builtin op data types. +// These tests will occur with the evaluation tests of individual operators, +// not here. + +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..733c3f4c7fa0605f24a1e6b4c458e34310c079c4 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -0,0 +1,100 @@ +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + +licenses(["notice"]) # Apache 2.0 + +gen_selected_ops( + name = "smartreply_ops", + model = "@tflite_smartreply//:smartreply.tflite", +) + +cc_library( + name = "custom_ops", + srcs = [ + "ops/extract_feature.cc", + "ops/normalize.cc", + "ops/predict.cc", + ":smartreply_ops", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + "@farmhash_archive//:farmhash", + ], +) + +cc_library( + name = "predictor_lib", + srcs = ["predictor.cc"], + hdrs = ["predictor.h"], + copts = tflite_copts(), + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "extract_feature_op_test", + size = "small", + srcs = ["ops/extract_feature_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@farmhash_archive//:farmhash", + ], +) + +cc_test( + name = "normalize_op_test", + size = "small", + srcs = ["ops/normalize_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "predict_op_test", + size = "small", + srcs = ["ops/predict_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..75ed9432c8fcdfd77a64d3c659e6336c977cdda2 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f8767b443a2aa64b666c3b6bfb7db30cc0be62ea --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -0,0 +1,65 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +filegroup( + name = "assets", + srcs = [ + "@tflite_smartreply//:model_files", + ], +) + +android_binary( + name = "SmartReplyDemo", + srcs = glob(["java/**/*.java"]), + assets = [":assets"], + assets_dir = "", + custom_package = "com.example.android.smartreply", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + ":smartreply_runtime", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +cc_library( + name = "smartreply_runtime", + srcs = ["libsmartreply_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libsmartreply_jni.so", + deps = [ + ":smartreply_jni_lib", + ], +) + +cc_library( + name = "smartreply_jni_lib", + srcs = [ + "smartreply_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/models/smartreply:predictor_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3c882ffc43fde577801428151a43b592e8faaed1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(glob(["*"])) + +filegroup( + name = "assets_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0a5b46b5f8d5fd6a0297c8056bb2fb9b6ad9ada --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt @@ -0,0 +1,16 @@ +Ok +Yes +No +👍 +☺ +😟 +❤️ +Lol +Thanks +Got it +Done +Nice +I don't know +What? +Why? +What's up? diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..02fec9ae5e971ad756ae6c2b0149a6aacfa27cad --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.app.Activity; +import android.os.Bundle; +import android.os.Handler; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.EditText; +import android.widget.TextView; + +/** + * The main (and only) activity of this demo app. Displays a text box which updates as messages are + * received. + */ +public class MainActivity extends Activity { + private static final String TAG = "SmartReplyDemo"; + private SmartReplyClient client; + + private Button sendButton; + private TextView messageTextView; + private EditText messageInput; + + private Handler handler; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + Log.v(TAG, "onCreate"); + setContentView(R.layout.main_activity); + + client = new SmartReplyClient(getApplicationContext()); + handler = new Handler(); + + sendButton = (Button) findViewById(R.id.send_button); + sendButton.setOnClickListener( + (View v) -> { + send(messageInput.getText().toString()); + }); + + messageTextView = (TextView) findViewById(R.id.message_text); + messageInput = (EditText) findViewById(R.id.message_input); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.loadModel(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unloadModel(); + }); + } + + private void send(final String message) { + handler.post( + () -> { + messageTextView.append("Input: " + message + "\n"); + + SmartReply[] ans = client.predict(new String[] {message}); + for (SmartReply reply : ans) { + appendMessage("Reply: " + reply.getText()); + } + appendMessage("------"); + }); + } + + private void appendMessage(final String message) { + handler.post( + () -> { + messageTextView.append(message + "\n"); + }); + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java new file mode 100644 index 0000000000000000000000000000000000000000..3357fd17c11f870d1b0998bb26ffa9abf149686b --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.support.annotation.Keep; + +/** + * SmartReply contains predicted message, and confidence. + * + *

NOTE: this class used by JNI, class name and constructor should not be obfuscated. + */ +@Keep +public class SmartReply { + + private final String text; + private final float score; + + @Keep + public SmartReply(String text, float score) { + this.text = text; + this.score = score; + } + + public String getText() { + return text; + } + + public float getScore() { + return score; + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java new file mode 100644 index 0000000000000000000000000000000000000000..d5b1ac0ffbc47283aa0c1bf68c0a85ad6228cdcc --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.support.annotation.Keep; +import android.support.annotation.WorkerThread; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.List; + +/** Interface to load TfLite model and provide predictions. */ +public class SmartReplyClient implements AutoCloseable { + private static final String TAG = "SmartReplyDemo"; + private static final String MODEL_PATH = "smartreply.tflite"; + private static final String BACKOFF_PATH = "backoff_response.txt"; + private static final String JNI_LIB = "smartreply_jni"; + + private final Context context; + private long storage; + private MappedByteBuffer model; + + private volatile boolean isLibraryLoaded; + + public SmartReplyClient(Context context) { + this.context = context; + } + + public boolean isLoaded() { + return storage != 0; + } + + @WorkerThread + public synchronized void loadModel() { + if (!isLibraryLoaded) { + System.loadLibrary(JNI_LIB); + isLibraryLoaded = true; + } + + try { + model = loadModelFile(); + String[] backoff = loadBackoffList(); + storage = loadJNI(model, backoff); + } catch (IOException e) { + Log.e(TAG, "Fail to load model", e); + return; + } + } + + @WorkerThread + public synchronized SmartReply[] predict(String[] input) { + if (storage != 0) { + return predictJNI(storage, input); + } else { + return new SmartReply[] {}; + } + } + + @WorkerThread + public synchronized void unloadModel() { + close(); + } + + @Override + public synchronized void close() { + if (storage != 0) { + unloadJNI(storage); + storage = 0; + } + } + + private MappedByteBuffer loadModelFile() throws IOException { + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + try { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } finally { + inputStream.close(); + } + } + + private String[] loadBackoffList() throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH))); + String line; + while ((line = reader.readLine()) != null) { + if (!line.isEmpty()) { + labelList.add(line); + } + } + reader.close(); + String[] ans = new String[labelList.size()]; + labelList.toArray(ans); + return ans; + } + + @Keep + private native long loadJNI(MappedByteBuffer buffer, String[] backoff); + + @Keep + private native SmartReply[] predictJNI(long storage, String[] text); + + @Keep + private native void unloadJNI(long storage); +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml new file mode 100644 index 0000000000000000000000000000000000000000..23b4cadc007a4457d33b8c8fecf9b1e7b7436320 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml @@ -0,0 +1,44 @@ + + + + + + + + + + +